From 39dbcb4dfb2a6e09bb2418d16445cd45631f8d24 Mon Sep 17 00:00:00 2001
From: yangyaming
Date: Mon, 23 Oct 2017 17:47:22 +0800
Subject: [PATCH] Give option to disable converting from transcription text to ids.

---
 data_utils/data.py                         | 13 ++++++++++---
 data_utils/featurizer/speech_featurizer.py |  8 +++++---
 deploy/demo_server.py                      |  3 ++-
 infer.py                                   |  6 +++---
 test.py                                    |  6 +++---
 tools/tune.py                              |  6 +++---
 6 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/data_utils/data.py b/data_utils/data.py
index 71ba2434..edd4047e 100644
--- a/data_utils/data.py
+++ b/data_utils/data.py
@@ -55,6 +55,10 @@ class DataGenerator(object):
     :type num_threads: int
     :param random_seed: Random seed.
     :type random_seed: int
+    :param keep_transcription_text: If set to True, transcription text will
+                                    be passed forward directly without
+                                    converting to index sequence.
+    :type keep_transcription_text: bool
     """

     def __init__(self,
@@ -69,7 +73,8 @@ class DataGenerator(object):
                  specgram_type='linear',
                  use_dB_normalization=True,
                  num_threads=multiprocessing.cpu_count() // 2,
-                 random_seed=0):
+                 random_seed=0,
+                 keep_transcription_text=False):
         self._max_duration = max_duration
         self._min_duration = min_duration
         self._normalizer = FeatureNormalizer(mean_std_filepath)
@@ -84,6 +89,7 @@ class DataGenerator(object):
             use_dB_normalization=use_dB_normalization)
         self._num_threads = num_threads
         self._rng = random.Random(random_seed)
+        self._keep_transcription_text = keep_transcription_text
         self._epoch = 0
         # for caching tar files info
         self._local_data = local()
@@ -107,9 +113,10 @@ class DataGenerator(object):
         else:
             speech_segment = SpeechSegment.from_file(filename, transcript)
         self._augmentation_pipeline.transform_audio(speech_segment)
-        specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
+        specgram, transcript_part = self._speech_featurizer.featurize(
+            speech_segment, self._keep_transcription_text)
         specgram = self._normalizer.apply(specgram)
-        return specgram, text_ids
+        return specgram, transcript_part

     def batch_reader_creator(self,
                              manifest_path,
diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py
index a947588d..4555dc31 100644
--- a/data_utils/featurizer/speech_featurizer.py
+++ b/data_utils/featurizer/speech_featurizer.py
@@ -60,12 +60,12 @@ class SpeechFeaturizer(object):
             target_dB=target_dB)
         self._text_featurizer = TextFeaturizer(vocab_filepath)

-    def featurize(self, speech_segment):
+    def featurize(self, speech_segment, keep_transcription_text):
         """Extract features for speech segment.

         1. For audio parts, extract the audio features.
-        2. For transcript parts, convert text string to a list of token indices
-           in char-level.
+        2. For transcript parts, keep the original text or convert text string
+           to a list of token indices in char-level.

         :param audio_segment: Speech segment to extract features from.
         :type audio_segment: SpeechSegment
@@ -74,6 +74,8 @@ class SpeechFeaturizer(object):
         :rtype: tuple
         """
         audio_feature = self._audio_featurizer.featurize(speech_segment)
+        if keep_transcription_text:
+            return audio_feature, speech_segment.transcript
         text_ids = self._text_featurizer.featurize(speech_segment.transcript)
         return audio_feature, text_ids

diff --git a/deploy/demo_server.py b/deploy/demo_server.py
index b007c751..3e81c0c5 100644
--- a/deploy/demo_server.py
+++ b/deploy/demo_server.py
@@ -146,7 +146,8 @@ def start_server():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=1)
+        num_threads=1,
+        keep_transcription_text=True)
     # prepare ASR model
     ds2_model = DeepSpeech2Model(
         vocab_size=data_generator.vocab_size,
diff --git a/infer.py b/infer.py
index a30d48d6..74524602 100644
--- a/infer.py
+++ b/infer.py
@@ -68,7 +68,8 @@ def infer():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=1)
+        num_threads=1,
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.infer_manifest,
         batch_size=args.num_samples,
@@ -103,8 +104,7 @@ def infer():

     error_rate_func = cer if args.error_rate_type == 'cer' else wer
     target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in infer_data
+        transcript for _, transcript in infer_data
     ]
     for target, result in zip(target_transcripts, result_transcripts):
         print("\nTarget Transcription: %s\nOutput Transcription: %s" %
diff --git a/test.py b/test.py
index 94c09150..5466f960 100644
--- a/test.py
+++ b/test.py
@@ -69,7 +69,8 @@ def evaluate():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data)
+        num_threads=args.num_proc_data,
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.test_manifest,
         batch_size=args.batch_size,
@@ -104,8 +105,7 @@ def evaluate():
             language_model_path=args.lang_model_path,
             num_processes=args.num_proc_bsearch)
         target_transcripts = [
-            ''.join([data_generator.vocab_list[token] for token in transcript])
-            for _, transcript in infer_data
+            transcript for _, transcript in infer_data
         ]
         for target, result in zip(target_transcripts, result_transcripts):
             error_sum += error_rate_func(target, result)
diff --git a/tools/tune.py b/tools/tune.py
index 233ec4ab..99ffb5f5 100644
--- a/tools/tune.py
+++ b/tools/tune.py
@@ -87,7 +87,8 @@ def tune():
         mean_std_filepath=args.mean_std_path,
         augmentation_config='{}',
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data)
+        num_threads=args.num_proc_data,
+        keep_transcription_text=True)

     audio_data = paddle.layer.data(
         name="audio_spectrogram",
@@ -164,8 +165,7 @@ def tune():
     ]

     target_transcripts = [
-        ''.join([data_generator.vocab_list[token] for token in transcript])
-        for _, transcript in infer_data
+        transcript for _, transcript in infer_data
     ]

     num_ins += len(target_transcripts)
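
Usage sketch (reviewer note, not part of the patch): the snippet below illustrates, under stated assumptions, how the new keep_transcription_text flag is meant to be used on the decoding path, modeled on the infer.py hunk above. The file paths, the batch size, and the vocab_filepath keyword do not appear in the hunks shown and are illustrative placeholders; only the arguments visible in the diff are taken from the patch itself.

# A minimal sketch, assuming DataGenerator and batch_reader_creator behave as
# shown in the hunks above. All file paths and the batch size are placeholders.
from data_utils.data import DataGenerator

data_generator = DataGenerator(
    vocab_filepath='data/vocab.txt',        # placeholder path (keyword assumed)
    mean_std_filepath='data/mean_std.npz',  # placeholder path
    augmentation_config='{}',
    specgram_type='linear',
    num_threads=1,
    keep_transcription_text=True)           # keep raw transcript strings

batch_reader = data_generator.batch_reader_creator(
    manifest_path='data/manifest.test',     # placeholder path
    batch_size=4)

infer_data = next(batch_reader())
# With keep_transcription_text=True, each sample's second element is already
# the original text string, so no vocabulary lookup is needed:
target_transcripts = [transcript for _, transcript in infer_data]
# With the default (False), it would instead be a list of token indices, and
# the old ''.join([data_generator.vocab_list[token] for token in transcript])
# conversion removed by this patch would still be required.

The design point, as reflected in the touched files, is that raw text is only needed where results are compared against reference transcriptions (infer.py, test.py, tools/tune.py) or shown to users (deploy/demo_server.py); training paths still receive token index sequences because keep_transcription_text defaults to False.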