From ac0ae57ef24412f971216acabbc14abfce0f65e2 Mon Sep 17 00:00:00 2001 From: Junkun Date: Wed, 4 Aug 2021 13:42:03 -0700 Subject: [PATCH] add collactor and evaluation code for ST --- deepspeech/exps/u2_st/model.py | 29 +- deepspeech/io/collator_st.py | 666 ++++++++++++++++++++++++++++++ deepspeech/io/dataset.py | 17 +- deepspeech/models/u2_st.py | 734 +++++++++++++++++++++++++++++++++ deepspeech/utils/bleu_score.py | 53 +++ 5 files changed, 1484 insertions(+), 15 deletions(-) create mode 100644 deepspeech/io/collator_st.py create mode 100644 deepspeech/models/u2_st.py create mode 100644 deepspeech/utils/bleu_score.py diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py index 21323fc9..867d1899 100644 --- a/deepspeech/exps/u2_st/model.py +++ b/deepspeech/exps/u2_st/model.py @@ -24,7 +24,6 @@ from typing import Tuple import numpy as np import paddle -import sacrebleu from paddle import distributed as dist from paddle.io import DataLoader from yacs.config import CfgNode @@ -32,6 +31,7 @@ from yacs.config import CfgNode from deepspeech.io.collator_st import KaldiPrePorocessedCollator from deepspeech.io.collator_st import SpeechCollator from deepspeech.io.collator_st import TripletKaldiPrePorocessedCollator +from deepspeech.io.collator_st import TripletSpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import TripletManifestDataset from deepspeech.io.sampler import SortagradBatchSampler @@ -40,6 +40,7 @@ from deepspeech.models.u2_st import U2STModel from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.scheduler import WarmupLR from deepspeech.training.trainer import Trainer +from deepspeech.utils import bleu_score from deepspeech.utils import ctc_utils from deepspeech.utils import error_rate from deepspeech.utils import layer_tools @@ -248,7 +249,11 @@ class U2STTrainer(Trainer): dev_dataset = Dataset.from_config(config) if config.collator.raw_wav: - TestCollator = Collator = SpeechCollator + if config.model.model_conf.asr_weight > 0.: + Collator = TripletSpeechCollator + TestCollator = SpeechCollator + else: + TestCollator = Collator = SpeechCollator # Not yet implement the mtl loader for raw_wav. else: if config.model.model_conf.asr_weight > 0.: @@ -393,7 +398,7 @@ class U2STTester(U2STTrainer): lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. decoding_method='attention', # Decoding method. Options: 'attention', 'ctc_greedy_search', # 'ctc_prefix_beam_search', 'attention_rescoring' - error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + error_rate_type='bleu', # Error rate type for evaluation. Options `bleu`, 'char_bleu' num_proc_bsearch=8, # # of CPUs for beam search. beam_size=10, # Beam search width. batch_size=16, # decoding batch size @@ -428,10 +433,10 @@ class U2STTester(U2STTrainer): audio_len, texts, texts_len, + bleu_func, fout=None): cfg = self.config.decoding len_refs, num_ins = 0, 0 - bleu_func = sacrebleu.corpus_bleu start_time = time.time() text_feature = self.test_loader.collate_fn.text_feature @@ -487,6 +492,9 @@ class U2STTester(U2STTrainer): self.model.eval() logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + cfg = self.config.decoding + bleu_func = bleu_score.char_bleu if cfg.error_rate_type == 'char-bleu' else bleu_score.bleu + stride_ms = self.test_loader.collate_fn.stride_ms hyps, refs = [], [] len_refs, num_ins = 0, 0 @@ -495,7 +503,7 @@ class U2STTester(U2STTrainer): with open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): metrics = self.compute_translation_metrics( - *batch, fout=fout) + *batch, bleu_func=bleu_func, fout=fout) hyps += metrics['hyps'] refs += metrics['refs'] bleu = metrics['bleu'] @@ -504,19 +512,16 @@ class U2STTester(U2STTrainer): len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] rtf = num_time / (num_frames * stride_ms) - logger.info("RTF: %f, BELU (%d) = %f" % - (rtf, num_ins, bleu)) + logger.info("RTF: %f, BELU (%d) = %f" % (rtf, num_ins, bleu)) rtf = num_time / (num_frames * stride_ms) msg = "Test: " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) msg += "RTF: {}, ".format(rtf) - msg += "Test set [%s]: %s" % ( - len(hyps), str(sacrebleu.corpus_bleu(hyps, [refs]))) + msg += "Test set [%s]: %s" % (len(hyps), str(bleu_func(hyps, [refs]))) logger.info(msg) - bleu_meta_path = os.path.splitext( - self.args.result_file)[0] + '.bleu' + bleu_meta_path = os.path.splitext(self.args.result_file)[0] + '.bleu' err_type_str = "BLEU" with open(bleu_meta_path, 'w') as f: data = json.dumps({ @@ -527,7 +532,7 @@ class U2STTester(U2STTrainer): "rtf": rtf, err_type_str: - sacrebleu.corpus_bleu(hyps, [refs]).score, + bleu_func(hyps, [refs]).score, "dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0, "process_hour": num_time / 1000.0 / 3600.0, diff --git a/deepspeech/io/collator_st.py b/deepspeech/io/collator_st.py new file mode 100644 index 00000000..34933312 --- /dev/null +++ b/deepspeech/io/collator_st.py @@ -0,0 +1,666 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +from collections import namedtuple +from typing import Optional +from typing import Tuple + +import kaldiio +import numpy as np +from yacs.config import CfgNode + +from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline +from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer +from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer +from deepspeech.frontend.normalizer import FeatureNormalizer +from deepspeech.frontend.speech import SpeechSegment +from deepspeech.frontend.utility import IGNORE_ID +from deepspeech.io.utility import pad_sequence +from deepspeech.utils.log import Log + +__all__ = ["SpeechCollator", "KaldiPrePorocessedCollator"] + +logger = Log(__name__).getlog() + +# namedtupe need global for pickle. +TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object']) + + +class SpeechCollator(): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + augmentation_config="", + random_seed=0, + mean_std_filepath="", + unit_type="char", + vocab_filepath="", + spm_model_prefix="", + specgram_type='linear', # 'linear', 'mfcc', 'fbank' + feat_dim=0, # 'mfcc', 'fbank' + delta_delta=False, # 'mfcc', 'fbank' + stride_ms=10.0, # ms + window_ms=20.0, # ms + n_fft=None, # fft points + max_freq=None, # None for samplerate/2 + target_sample_rate=16000, # target sample rate + use_dB_normalization=True, + target_dB=-20, + dither=1.0, # feature dither + keep_transcription_text=False)) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + @classmethod + def from_config(cls, config): + """Build a SpeechCollator object from a config. + + Args: + config (yacs.config.CfgNode): configs object. + + Returns: + SpeechCollator: collator object. + """ + assert 'augmentation_config' in config.collator + assert 'keep_transcription_text' in config.collator + assert 'mean_std_filepath' in config.collator + assert 'vocab_filepath' in config.collator + assert 'specgram_type' in config.collator + assert 'n_fft' in config.collator + assert config.collator + + if isinstance(config.collator.augmentation_config, (str, bytes)): + if config.collator.augmentation_config: + aug_file = io.open( + config.collator.augmentation_config, + mode='r', + encoding='utf8') + else: + aug_file = io.StringIO(initial_value='{}', newline='') + else: + aug_file = config.collator.augmentation_config + assert isinstance(aug_file, io.StringIO) + + speech_collator = cls( + aug_file=aug_file, + random_seed=0, + mean_std_filepath=config.collator.mean_std_filepath, + unit_type=config.collator.unit_type, + vocab_filepath=config.collator.vocab_filepath, + spm_model_prefix=config.collator.spm_model_prefix, + specgram_type=config.collator.specgram_type, + feat_dim=config.collator.feat_dim, + delta_delta=config.collator.delta_delta, + stride_ms=config.collator.stride_ms, + window_ms=config.collator.window_ms, + n_fft=config.collator.n_fft, + max_freq=config.collator.max_freq, + target_sample_rate=config.collator.target_sample_rate, + use_dB_normalization=config.collator.use_dB_normalization, + target_dB=config.collator.target_dB, + dither=config.collator.dither, + keep_transcription_text=config.collator.keep_transcription_text) + return speech_collator + + def __init__( + self, + aug_file, + mean_std_filepath, + vocab_filepath, + spm_model_prefix, + random_seed=0, + unit_type="char", + specgram_type='linear', # 'linear', 'mfcc', 'fbank' + feat_dim=0, # 'mfcc', 'fbank' + delta_delta=False, # 'mfcc', 'fbank' + stride_ms=10.0, # ms + window_ms=20.0, # ms + n_fft=None, # fft points + max_freq=None, # None for samplerate/2 + target_sample_rate=16000, # target sample rate + use_dB_normalization=True, + target_dB=-20, + dither=1.0, + keep_transcription_text=True): + """SpeechCollator Collator + + Args: + unit_type(str): token unit type, e.g. char, word, spm + vocab_filepath (str): vocab file path. + mean_std_filepath (str): mean and std file path, which suffix is *.npy + spm_model_prefix (str): spm model prefix, need if `unit_type` is spm. + augmentation_config (str, optional): augmentation json str. Defaults to '{}'. + stride_ms (float, optional): stride size in ms. Defaults to 10.0. + window_ms (float, optional): window size in ms. Defaults to 20.0. + n_fft (int, optional): fft points for rfft. Defaults to None. + max_freq (int, optional): max cut freq. Defaults to None. + target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000. + specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'. + feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None. + delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False. + use_dB_normalization (bool, optional): do dB normalization. Defaults to True. + target_dB (int, optional): target dB. Defaults to -20. + random_seed (int, optional): for random generator. Defaults to 0. + keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. + if ``keep_transcription_text`` is False, text is token ids else is raw string. + + Do augmentations + Padding audio features with zeros to make them have the same shape (or + a user-defined shape) within one batch. + """ + self._keep_transcription_text = keep_transcription_text + + self._local_data = TarLocalData(tar2info={}, tar2object={}) + self._augmentation_pipeline = AugmentationPipeline( + augmentation_config=aug_file.read(), random_seed=random_seed) + + self._normalizer = FeatureNormalizer( + mean_std_filepath) if mean_std_filepath else None + + self._stride_ms = stride_ms + self._target_sample_rate = target_sample_rate + + self._speech_featurizer = SpeechFeaturizer( + unit_type=unit_type, + vocab_filepath=vocab_filepath, + spm_model_prefix=spm_model_prefix, + specgram_type=specgram_type, + feat_dim=feat_dim, + delta_delta=delta_delta, + stride_ms=stride_ms, + window_ms=window_ms, + n_fft=n_fft, + max_freq=max_freq, + target_sample_rate=target_sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB, + dither=dither) + + def _parse_tar(self, file): + """Parse a tar file to get a tarfile object + and a map containing tarinfoes + """ + result = {} + f = tarfile.open(file) + for tarinfo in f.getmembers(): + result[tarinfo.name] = tarinfo + return f, result + + def _subfile_from_tar(self, file): + """Get subfile object from tar. + + It will return a subfile object from tar file + and cached tar file info for next reading request. + """ + tarpath, filename = file.split(':', 1)[1].split('#', 1) + if 'tar2info' not in self._local_data.__dict__: + self._local_data.tar2info = {} + if 'tar2object' not in self._local_data.__dict__: + self._local_data.tar2object = {} + if tarpath not in self._local_data.tar2info: + object, infoes = self._parse_tar(tarpath) + self._local_data.tar2info[tarpath] = infoes + self._local_data.tar2object[tarpath] = object + return self._local_data.tar2object[tarpath].extractfile( + self._local_data.tar2info[tarpath][filename]) + + def process_utterance(self, audio_file, translation): + """Load, augment, featurize and normalize for speech data. + + :param audio_file: Filepath or file object of audio file. + :type audio_file: str | file + :param translation: translation text. + :type translation: str + :return: Tuple of audio feature tensor and data of translation part, + where translation part could be token ids or text. + :rtype: tuple of (2darray, list) + """ + if isinstance(audio_file, str) and audio_file.startswith('tar:'): + speech_segment = SpeechSegment.from_file( + self._subfile_from_tar(audio_file), translation) + else: + speech_segment = SpeechSegment.from_file(audio_file, translation) + + # audio augment + self._augmentation_pipeline.transform_audio(speech_segment) + + specgram, translation_part = self._speech_featurizer.featurize( + speech_segment, self._keep_transcription_text) + if self._normalizer: + specgram = self._normalizer.apply(specgram) + + # specgram augment + specgram = self._augmentation_pipeline.transform_feature(specgram) + specgram = specgram.transpose([1, 0]) + return specgram, translation_part + + def __call__(self, batch): + """batch examples + + Args: + batch ([List]): batch is (audio, text) + audio (np.ndarray) shape (D, T) + text (List[int] or str): shape (U,) + + Returns: + tuple(audio, text, audio_lens, text_lens): batched data. + audio : (B, Tmax, D) + audio_lens: (B) + text : (B, Umax) + text_lens: (B) + """ + audios = [] + audio_lens = [] + texts = [] + text_lens = [] + utts = [] + for utt, audio, text in batch: + audio, text = self.process_utterance(audio, text) + #utt + utts.append(utt) + # audio + audios.append(audio) # [T, D] + audio_lens.append(audio.shape[0]) + # text + # for training, text is token ids + # else text is string, convert to unicode ord + tokens = [] + if self._keep_transcription_text: + assert isinstance(text, str), (type(text), text) + tokens = [ord(t) for t in text] + else: + tokens = text # token ids + tokens = tokens if isinstance(tokens, np.ndarray) else np.array( + tokens, dtype=np.int64) + texts.append(tokens) + text_lens.append(tokens.shape[0]) + + padded_audios = pad_sequence( + audios, padding_value=0.0).astype(np.float32) #[B, T, D] + audio_lens = np.array(audio_lens).astype(np.int64) + padded_texts = pad_sequence( + texts, padding_value=IGNORE_ID).astype(np.int64) + text_lens = np.array(text_lens).astype(np.int64) + return utts, padded_audios, audio_lens, padded_texts, text_lens + + @property + def manifest(self): + return self._manifest + + @property + def vocab_size(self): + return self._speech_featurizer.vocab_size + + @property + def vocab_list(self): + return self._speech_featurizer.vocab_list + + @property + def vocab_dict(self): + return self._speech_featurizer.vocab_dict + + @property + def text_feature(self): + return self._speech_featurizer.text_feature + + @property + def feature_size(self): + return self._speech_featurizer.feature_size + + @property + def stride_ms(self): + return self._speech_featurizer.stride_ms + + +class TripletSpeechCollator(SpeechCollator): + def process_utterance(self, audio_file, translation, transcript): + """Load, augment, featurize and normalize for speech data. + + :param audio_file: Filepath or file object of audio file. + :type audio_file: str | file + :param translation: translation text. + :type translation: str + :return: Tuple of audio feature tensor and data of translation part, + where translation part could be token ids or text. + :rtype: tuple of (2darray, list) + """ + if isinstance(audio_file, str) and audio_file.startswith('tar:'): + speech_segment = SpeechSegment.from_file( + self._subfile_from_tar(audio_file), translation) + else: + speech_segment = SpeechSegment.from_file(audio_file, translation) + + # audio augment + self._augmentation_pipeline.transform_audio(speech_segment) + + specgram, translation_part = self._speech_featurizer.featurize( + speech_segment, self._keep_transcription_text) + transcript_part = self._speech_featurizer._text_featurizer.featurize( + transcript) + if self._normalizer: + specgram = self._normalizer.apply(specgram) + + # specgram augment + specgram = self._augmentation_pipeline.transform_feature(specgram) + specgram = specgram.transpose([1, 0]) + return specgram, translation_part, transcript_part + + def __call__(self, batch): + """batch examples + + Args: + batch ([List]): batch is (audio, text) + audio (np.ndarray) shape (D, T) + text (List[int] or str): shape (U,) + + Returns: + tuple(audio, text, audio_lens, text_lens): batched data. + audio : (B, Tmax, D) + audio_lens: (B) + text : (B, Umax) + text_lens: (B) + """ + audios = [] + audio_lens = [] + translation_text = [] + translation_text_lens = [] + transcription_text = [] + transcription_text_lens = [] + + utts = [] + for utt, audio, translation, transcription in batch: + audio, translation, transcription = self.process_utterance( + audio, translation, transcription) + #utt + utts.append(utt) + # audio + audios.append(audio) # [T, D] + audio_lens.append(audio.shape[0]) + # text + # for training, text is token ids + # else text is string, convert to unicode ord + tokens = [[], []] + for idx, text in enumerate([translation, transcription]): + if self._keep_transcription_text: + assert isinstance(text, str), (type(text), text) + tokens[idx] = [ord(t) for t in text] + else: + tokens[idx] = text # token ids + tokens[idx] = tokens[idx] if isinstance( + tokens[idx], np.ndarray) else np.array( + tokens[idx], dtype=np.int64) + translation_text.append(tokens[0]) + translation_text_lens.append(tokens[0].shape[0]) + transcription_text.append(tokens[1]) + transcription_text_lens.append(tokens[1].shape[0]) + + padded_audios = pad_sequence( + audios, padding_value=0.0).astype(np.float32) #[B, T, D] + audio_lens = np.array(audio_lens).astype(np.int64) + padded_translation = pad_sequence( + translation_text, padding_value=IGNORE_ID).astype(np.int64) + translation_lens = np.array(translation_text_lens).astype(np.int64) + padded_transcription = pad_sequence( + transcription_text, padding_value=IGNORE_ID).astype(np.int64) + transcription_lens = np.array(transcription_text_lens).astype(np.int64) + return utts, padded_audios, audio_lens, ( + padded_translation, padded_transcription), (translation_lens, + transcription_lens) + + +class KaldiPrePorocessedCollator(SpeechCollator): + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + default = CfgNode( + dict( + augmentation_config="", + random_seed=0, + unit_type="char", + vocab_filepath="", + spm_model_prefix="", + feat_dim=0, + stride_ms=10.0, + keep_transcription_text=False)) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + @classmethod + def from_config(cls, config): + """Build a SpeechCollator object from a config. + + Args: + config (yacs.config.CfgNode): configs object. + + Returns: + SpeechCollator: collator object. + """ + assert 'augmentation_config' in config.collator + assert 'keep_transcription_text' in config.collator + assert 'vocab_filepath' in config.collator + assert config.collator + + if isinstance(config.collator.augmentation_config, (str, bytes)): + if config.collator.augmentation_config: + aug_file = io.open( + config.collator.augmentation_config, + mode='r', + encoding='utf8') + else: + aug_file = io.StringIO(initial_value='{}', newline='') + else: + aug_file = config.collator.augmentation_config + assert isinstance(aug_file, io.StringIO) + + speech_collator = cls( + aug_file=aug_file, + random_seed=0, + unit_type=config.collator.unit_type, + vocab_filepath=config.collator.vocab_filepath, + spm_model_prefix=config.collator.spm_model_prefix, + feat_dim=config.collator.feat_dim, + stride_ms=config.collator.stride_ms, + keep_transcription_text=config.collator.keep_transcription_text) + return speech_collator + + def __init__(self, + aug_file, + vocab_filepath, + spm_model_prefix, + random_seed=0, + unit_type="char", + feat_dim=0, + stride_ms=10.0, + keep_transcription_text=True): + """SpeechCollator Collator + + Args: + unit_type(str): token unit type, e.g. char, word, spm + vocab_filepath (str): vocab file path. + spm_model_prefix (str): spm model prefix, need if `unit_type` is spm. + augmentation_config (str, optional): augmentation json str. Defaults to '{}'. + random_seed (int, optional): for random generator. Defaults to 0. + keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. + if ``keep_transcription_text`` is False, text is token ids else is raw string. + + Do augmentations + Padding audio features with zeros to make them have the same shape (or + a user-defined shape) within one batch. + """ + self._keep_transcription_text = keep_transcription_text + self._feat_dim = feat_dim + self._stride_ms = stride_ms + + self._local_data = TarLocalData(tar2info={}, tar2object={}) + self._augmentation_pipeline = AugmentationPipeline( + augmentation_config=aug_file.read(), random_seed=random_seed) + + self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath, + spm_model_prefix) + + def process_utterance(self, audio_file, translation): + """Load, augment, featurize and normalize for speech data. + + :param audio_file: Filepath or file object of kaldi processed feature. + :type audio_file: str | file + :param translation: Translation text. + :type translation: str + :return: Tuple of audio feature tensor and data of translation part, + where translation part could be token ids or text. + :rtype: tuple of (2darray, list) + """ + specgram = kaldiio.load_mat(audio_file) + specgram = specgram.transpose([1, 0]) + assert specgram.shape[ + 0] == self._feat_dim, 'expect feat dim {}, but got {}'.format( + self._feat_dim, specgram.shape[0]) + + # specgram augment + specgram = self._augmentation_pipeline.transform_feature(specgram) + + specgram = specgram.transpose([1, 0]) + if self._keep_transcription_text: + return specgram, translation + else: + text_ids = self._text_featurizer.featurize(translation) + return specgram, text_ids + + @property + def manifest(self): + return self._manifest + + @property + def vocab_size(self): + return self._text_featurizer.vocab_size + + @property + def vocab_list(self): + return self._text_featurizer.vocab_list + + @property + def vocab_dict(self): + return self._text_featurizer.vocab_dict + + @property + def text_feature(self): + return self._text_featurizer + + @property + def feature_size(self): + return self._feat_dim + + @property + def stride_ms(self): + return self._stride_ms + + +class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator): + def process_utterance(self, audio_file, translation, transcript): + """Load, augment, featurize and normalize for speech data. + + :param audio_file: Filepath or file object of kali processed feature. + :type audio_file: str | file + :param translation: Translation text. + :type translation: str + :param transcript: Transcription text. + :type transcript: str + :return: Tuple of audio feature tensor and data of translation and transcription parts, + where translation and transcription parts could be token ids or text. + :rtype: tuple of (2darray, (list, list)) + """ + specgram = kaldiio.load_mat(audio_file) + specgram = specgram.transpose([1, 0]) + assert specgram.shape[ + 0] == self._feat_dim, 'expect feat dim {}, but got {}'.format( + self._feat_dim, specgram.shape[0]) + + # specgram augment + specgram = self._augmentation_pipeline.transform_feature(specgram) + + specgram = specgram.transpose([1, 0]) + if self._keep_transcription_text: + return specgram, translation, transcript + else: + translation_text_ids = self._text_featurizer.featurize(translation) + transcript_text_ids = self._text_featurizer.featurize(transcript) + return specgram, translation_text_ids, transcript_text_ids + + def __call__(self, batch): + """batch examples + + Args: + batch ([List]): batch is (audio, text) + audio (np.ndarray) shape (D, T) + translation (List[int] or str): shape (U,) + transcription (List[int] or str): shape (V,) + + Returns: + tuple(audio, text, audio_lens, text_lens): batched data. + audio : (B, Tmax, D) + audio_lens: (B) + translation_text : (B, Umax) + translation_text_lens: (B) + transcription_text : (B, Vmax) + transcription_text_lens: (B) + """ + audios = [] + audio_lens = [] + translation_text = [] + translation_text_lens = [] + transcription_text = [] + transcription_text_lens = [] + + utts = [] + for utt, audio, translation, transcription in batch: + audio, translation, transcription = self.process_utterance( + audio, translation, transcription) + #utt + utts.append(utt) + # audio + audios.append(audio) # [T, D] + audio_lens.append(audio.shape[0]) + # text + # for training, text is token ids + # else text is string, convert to unicode ord + tokens = [[], []] + for idx, text in enumerate([translation, transcription]): + if self._keep_transcription_text: + assert isinstance(text, str), (type(text), text) + tokens[idx] = [ord(t) for t in text] + else: + tokens[idx] = text # token ids + tokens[idx] = tokens[idx] if isinstance( + tokens[idx], np.ndarray) else np.array( + tokens[idx], dtype=np.int64) + translation_text.append(tokens[0]) + translation_text_lens.append(tokens[0].shape[0]) + transcription_text.append(tokens[1]) + transcription_text_lens.append(tokens[1].shape[0]) + + padded_audios = pad_sequence( + audios, padding_value=0.0).astype(np.float32) #[B, T, D] + audio_lens = np.array(audio_lens).astype(np.int64) + padded_translation = pad_sequence( + translation_text, padding_value=IGNORE_ID).astype(np.int64) + translation_lens = np.array(translation_text_lens).astype(np.int64) + padded_transcription = pad_sequence( + transcription_text, padding_value=IGNORE_ID).astype(np.int64) + transcription_lens = np.array(transcription_text_lens).astype(np.int64) + return utts, padded_audios, audio_lens, ( + padded_translation, padded_transcription), (translation_lens, + transcription_lens) diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 3fc4e988..ac7be1f9 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -19,9 +19,7 @@ from yacs.config import CfgNode from deepspeech.frontend.utility import read_manifest from deepspeech.utils.log import Log -__all__ = [ - "ManifestDataset", -] +__all__ = ["ManifestDataset", "TripletManifestDataset"] logger = Log(__name__).getlog() @@ -105,3 +103,16 @@ class ManifestDataset(Dataset): def __getitem__(self, idx): instance = self._manifest[idx] return instance["utt"], instance["feat"], instance["text"] + + +class TripletManifestDataset(ManifestDataset): + """ + For Joint Training of Speech Translation and ASR. + text: translation, + text1: transcript. + """ + + def __getitem__(self, idx): + instance = self._manifest[idx] + return instance["utt"], instance["feat"], instance["text"], instance[ + "text1"] diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st.py new file mode 100644 index 00000000..5eea139b --- /dev/null +++ b/deepspeech/models/u2_st.py @@ -0,0 +1,734 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""U2 ASR Model +Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition +(https://arxiv.org/pdf/2012.05481.pdf) +""" +import sys +import time +from collections import defaultdict +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import paddle +from paddle import jit +from paddle import nn +from yacs.config import CfgNode + +from deepspeech.frontend.utility import IGNORE_ID +from deepspeech.frontend.utility import load_cmvn +from deepspeech.modules.cmvn import GlobalCMVN +from deepspeech.modules.ctc import CTCDecoder +from deepspeech.modules.decoder import TransformerDecoder +from deepspeech.modules.encoder import ConformerEncoder +from deepspeech.modules.encoder import TransformerEncoder +from deepspeech.modules.loss import LabelSmoothingLoss +from deepspeech.modules.mask import make_pad_mask +from deepspeech.modules.mask import mask_finished_preds +from deepspeech.modules.mask import mask_finished_scores +from deepspeech.modules.mask import subsequent_mask +from deepspeech.utils import checkpoint +from deepspeech.utils import layer_tools +from deepspeech.utils.ctc_utils import remove_duplicates_and_blank +from deepspeech.utils.log import Log +from deepspeech.utils.tensor_utils import add_sos_eos +from deepspeech.utils.tensor_utils import pad_sequence +from deepspeech.utils.tensor_utils import th_accuracy +from deepspeech.utils.utility import log_add + +__all__ = ["U2STModel", "U2STInferModel"] + +logger = Log(__name__).getlog() + + +class U2STBaseModel(nn.Module): + """CTC-Attention hybrid Encoder-Decoder model""" + + @classmethod + def params(cls, config: Optional[CfgNode]=None) -> CfgNode: + # network architecture + default = CfgNode() + # allow add new item when merge_with_file + default.cmvn_file = "" + default.cmvn_file_type = "json" + default.input_dim = 0 + default.output_dim = 0 + # encoder related + default.encoder = 'transformer' + default.encoder_conf = CfgNode( + dict( + output_size=256, # dimension of attention + attention_heads=4, + linear_units=2048, # the number of units of position-wise feed forward + num_blocks=12, # the number of encoder blocks + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer='conv2d', # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before=True, + # use_cnn_module=True, + # cnn_module_kernel=15, + # activation_type='swish', + # pos_enc_layer_type='rel_pos', + # selfattention_layer_type='rel_selfattn', + )) + # decoder related + default.decoder = 'transformer' + default.decoder_conf = CfgNode( + dict( + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + self_attention_dropout_rate=0.0, + src_attention_dropout_rate=0.0, )) + # hybrid CTC/attention + default.model_conf = CfgNode( + dict( + asr_weight=0.0, + ctc_weight=0.0, + lsm_weight=0.1, # label smoothing option + length_normalized_loss=False, )) + + if config is not None: + config.merge_from_other_cfg(default) + return default + + def __init__(self, + vocab_size: int, + encoder: TransformerEncoder, + st_decoder: TransformerDecoder, + decoder: TransformerDecoder=None, + ctc: CTCDecoder=None, + ctc_weight: float=0.0, + asr_weight: float=0.0, + ignore_id: int=IGNORE_ID, + lsm_weight: float=0.0, + length_normalized_loss: bool=False): + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.asr_weight = asr_weight + + self.encoder = encoder + self.st_decoder = st_decoder + self.decoder = decoder + self.ctc = ctc + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, ) + + def forward( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + text: paddle.Tensor, + text_lengths: paddle.Tensor, + asr_text: paddle.Tensor=None, + asr_text_lengths: paddle.Tensor=None, + ) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[ + paddle.Tensor]]: + """Frontend + Encoder + Decoder + Calc loss + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + Returns: + total_loss, attention_loss, ctc_loss + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] == + text_lengths.shape[0]), (speech.shape, speech_lengths.shape, + text.shape, text_lengths.shape) + # 1. Encoder + start = time.time() + encoder_out, encoder_mask = self.encoder(speech, speech_lengths) + encoder_time = time.time() - start + #logger.debug(f"encoder time: {encoder_time}") + #TODO(Hui Zhang): sum not support bool type + #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] + encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( + 1) #[B, 1, T] -> [B] + + # 2a. ST-decoder branch + start = time.time() + loss_st, acc_st = self._calc_st_loss(encoder_out, encoder_mask, text, + text_lengths) + decoder_time = time.time() - start + + loss_asr_att = None + loss_asr_ctc = None + # 2b. ASR Attention-decoder branch + if self.asr_weight > 0.: + if self.ctc_weight != 1.0: + start = time.time() + loss_asr_att, acc_att = self._calc_att_loss( + encoder_out, encoder_mask, asr_text, asr_text_lengths) + decoder_time = time.time() - start + + # 2c. CTC branch + if self.ctc_weight != 0.0: + start = time.time() + loss_asr_ctc = self.ctc(encoder_out, encoder_out_lens, asr_text, + asr_text_lengths) + ctc_time = time.time() - start + + if loss_asr_ctc is None: + loss_asr = loss_asr_att + elif loss_asr_att is None: + loss_asr = loss_asr_ctc + else: + loss_asr = self.ctc_weight * loss_asr_ctc + (1 - self.ctc_weight + ) * loss_asr_att + loss = self.asr_weight * loss_asr + (1 - self.asr_weight) * loss_st + else: + loss = loss_st + return loss, loss_st, loss_asr_att, loss_asr_ctc + + def _calc_st_loss( + self, + encoder_out: paddle.Tensor, + encoder_mask: paddle.Tensor, + ys_pad: paddle.Tensor, + ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]: + """Calc attention loss. + + Args: + encoder_out (paddle.Tensor): [B, Tmax, D] + encoder_mask (paddle.Tensor): [B, 1, Tmax] + ys_pad (paddle.Tensor): [B, Umax] + ys_pad_lens (paddle.Tensor): [B] + + Returns: + Tuple[paddle.Tensor, float]: attention_loss, accuracy rate + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, ) + return loss_att, acc_att + + def _calc_att_loss( + self, + encoder_out: paddle.Tensor, + encoder_mask: paddle.Tensor, + ys_pad: paddle.Tensor, + ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]: + """Calc attention loss. + + Args: + encoder_out (paddle.Tensor): [B, Tmax, D] + encoder_mask (paddle.Tensor): [B, 1, Tmax] + ys_pad (paddle.Tensor): [B, Umax] + ys_pad_lens (paddle.Tensor): [B] + + Returns: + Tuple[paddle.Tensor, float]: attention_loss, accuracy rate + """ + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, + self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad, + ys_in_lens) + + # 2. Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, ) + return loss_att, acc_att + + def _forward_encoder( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Encoder pass. + + Args: + speech (paddle.Tensor): [B, Tmax, D] + speech_lengths (paddle.Tensor): [B] + decoding_chunk_size (int, optional): chuck size. Defaults to -1. + num_decoding_left_chunks (int, optional): nums chunks. Defaults to -1. + simulate_streaming (bool, optional): streaming or not. Defaults to False. + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: + encoder hiddens (B, Tmax, D), + encoder hiddens mask (B, 1, Tmax). + """ + # Let's assume B = batch_size + # 1. Encoder + if simulate_streaming and decoding_chunk_size > 0: + encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk( + speech, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + else: + encoder_out, encoder_mask = self.encoder( + speech, + speech_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks + ) # (B, maxlen, encoder_dim) + return encoder_out, encoder_mask + + def translate( + self, + speech: paddle.Tensor, + speech_lengths: paddle.Tensor, + beam_size: int=10, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False, ) -> paddle.Tensor: + """ Apply beam search on attention decoder + Args: + speech (paddle.Tensor): (batch, max_len, feat_dim) + speech_length (paddle.Tensor): (batch, ) + beam_size (int): beam size for beam search + decoding_chunk_size (int): decoding chunk for dynamic chunk + trained model. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here + simulate_streaming (bool): whether do encoder forward in a + streaming fashion + Returns: + paddle.Tensor: decoding result, (batch, max_result_len) + """ + assert speech.shape[0] == speech_lengths.shape[0] + assert decoding_chunk_size != 0 + device = speech.place + batch_size = speech.shape[0] + + # Let's assume B = batch_size and N = beam_size + # 1. Encoder + encoder_out, encoder_mask = self._forward_encoder( + speech, speech_lengths, decoding_chunk_size, + num_decoding_left_chunks, + simulate_streaming) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + encoder_dim = encoder_out.size(2) + running_size = batch_size * beam_size + encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view( + running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim) + encoder_mask = encoder_mask.unsqueeze(1).repeat( + 1, beam_size, 1, 1).view(running_size, 1, + maxlen) # (B*N, 1, max_len) + + hyps = paddle.ones( + [running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1) + # log scale score + scores = paddle.to_tensor( + [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float) + scores = scores.to(device).repeat(batch_size).unsqueeze(1).to( + device) # (B*N, 1) + end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1) + cache: Optional[List[paddle.Tensor]] = None + # 2. Decoder forward step by step + for i in range(1, maxlen + 1): + # Stop if all batch and all beam produce eos + # TODO(Hui Zhang): if end_flag.sum() == running_size: + if end_flag.cast(paddle.int64).sum() == running_size: + break + + # 2.1 Forward decoder step + hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( + running_size, 1, 1).to(device) # (B*N, i, i) + # logp: (B*N, vocab) + logp, cache = self.st_decoder.forward_one_step( + encoder_out, encoder_mask, hyps, hyps_mask, cache) + + # 2.2 First beam prune: select topk best prob at current time + top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) + top_k_logp = mask_finished_scores(top_k_logp, end_flag) + top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos) + + # 2.3 Seconde beam prune: select topk score with history + scores = scores + top_k_logp # (B*N, N), broadcast add + scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N) + scores, offset_k_index = scores.topk(k=beam_size) # (B, N) + scores = scores.view(-1, 1) # (B*N, 1) + + # 2.4. Compute base index in top_k_index, + # regard top_k_index as (B*N*N),regard offset_k_index as (B*N), + # then find offset_k_index in top_k_index + base_k_index = paddle.arange(batch_size).view(-1, 1).repeat( + 1, beam_size) # (B, N) + base_k_index = base_k_index * beam_size * beam_size + best_k_index = base_k_index.view(-1) + offset_k_index.view( + -1) # (B*N) + + # 2.5 Update best hyps + best_k_pred = paddle.index_select( + top_k_index.view(-1), index=best_k_index, axis=0) # (B*N) + best_hyps_index = best_k_index // beam_size + last_best_k_hyps = paddle.index_select( + hyps, index=best_hyps_index, axis=0) # (B*N, i) + hyps = paddle.cat( + (last_best_k_hyps, best_k_pred.view(-1, 1)), + dim=1) # (B*N, i+1) + + # 2.6 Update end flag + end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1) + + # 3. Select best of best + scores = scores.view(batch_size, beam_size) + # TODO: length normalization + best_index = paddle.argmax(scores, axis=-1).long() # (B) + best_hyps_index = best_index + paddle.arange( + batch_size, dtype=paddle.long) * beam_size + best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0) + best_hyps = best_hyps[:, 1:] + return best_hyps + + @jit.export + def subsampling_rate(self) -> int: + """ Export interface for c++ call, return subsampling_rate of the + model + """ + return self.encoder.embed.subsampling_rate + + @jit.export + def right_context(self) -> int: + """ Export interface for c++ call, return right_context of the model + """ + return self.encoder.embed.right_context + + @jit.export + def sos_symbol(self) -> int: + """ Export interface for c++ call, return sos symbol id of the model + """ + return self.sos + + @jit.export + def eos_symbol(self) -> int: + """ Export interface for c++ call, return eos symbol id of the model + """ + return self.eos + + @jit.export + def forward_encoder_chunk( + self, + xs: paddle.Tensor, + offset: int, + required_cache_size: int, + subsampling_cache: Optional[paddle.Tensor]=None, + elayers_output_cache: Optional[List[paddle.Tensor]]=None, + conformer_cnn_cache: Optional[List[paddle.Tensor]]=None, + ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[ + paddle.Tensor]]: + """ Export interface for c++ call, give input chunk xs, and return + output from time 0 to current chunk. + Args: + xs (paddle.Tensor): chunk input + subsampling_cache (Optional[paddle.Tensor]): subsampling cache + elayers_output_cache (Optional[List[paddle.Tensor]]): + transformer/conformer encoder layers output cache + conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer + cnn cache + Returns: + paddle.Tensor: output, it ranges from time 0 to current chunk. + paddle.Tensor: subsampling cache + List[paddle.Tensor]: attention cache + List[paddle.Tensor]: conformer cnn cache + """ + return self.encoder.forward_chunk( + xs, offset, required_cache_size, subsampling_cache, + elayers_output_cache, conformer_cnn_cache) + + @jit.export + def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: + """ Export interface for c++ call, apply linear transform and log + softmax before ctc + Args: + xs (paddle.Tensor): encoder output + Returns: + paddle.Tensor: activation before ctc + """ + return self.ctc.log_softmax(xs) + + @jit.export + def forward_attention_decoder( + self, + hyps: paddle.Tensor, + hyps_lens: paddle.Tensor, + encoder_out: paddle.Tensor, ) -> paddle.Tensor: + """ Export interface for c++ call, forward decoder with multiple + hypothesis from ctc prefix beam search and one encoder output + Args: + hyps (paddle.Tensor): hyps from ctc prefix beam search, already + pad sos at the begining, (B, T) + hyps_lens (paddle.Tensor): length of each hyp in hyps, (B) + encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D) + Returns: + paddle.Tensor: decoder output, (B, L) + """ + assert encoder_out.size(0) == 1 + num_hyps = hyps.size(0) + assert hyps_lens.size(0) == num_hyps + encoder_out = encoder_out.repeat(num_hyps, 1, 1) + # (B, 1, T) + encoder_mask = paddle.ones( + [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool) + # (num_hyps, max_hyps_len, vocab_size) + decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, + hyps_lens) + decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1) + return decoder_out + + @paddle.no_grad() + def decode(self, + feats: paddle.Tensor, + feats_lengths: paddle.Tensor, + text_feature: Dict[str, int], + decoding_method: str, + lang_model_path: str, + beam_alpha: float, + beam_beta: float, + beam_size: int, + cutoff_prob: float, + cutoff_top_n: int, + num_processes: int, + ctc_weight: float=0.0, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False): + """u2 decoding. + + Args: + feats (Tenosr): audio features, (B, T, D) + feats_lengths (Tenosr): (B) + text_feature (TextFeaturizer): text feature object. + decoding_method (str): decoding mode, e.g. + 'fullsentence', + 'simultaneous' + lang_model_path (str): lm path. + beam_alpha (float): lm weight. + beam_beta (float): length penalty. + beam_size (int): beam size for search + cutoff_prob (float): for prune. + cutoff_top_n (int): for prune. + num_processes (int): + ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0. + decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here. + num_decoding_left_chunks (int, optional): + number of left chunks for decoding. Defaults to -1. + simulate_streaming (bool, optional): simulate streaming inference. Defaults to False. + + Raises: + ValueError: when not support decoding_method. + + Returns: + List[List[int]]: transcripts. + """ + batch_size = feats.size(0) + + if decoding_method == 'fullsentence': + hyps = self.translate( + feats, + feats_lengths, + beam_size=beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) + hyps = [hyp.tolist() for hyp in hyps] + else: + raise ValueError(f"Not support decoding method: {decoding_method}") + + res = [text_feature.defeaturize(hyp) for hyp in hyps] + return res + + +class U2STModel(U2STBaseModel): + def __init__(self, configs: dict): + vocab_size, encoder, decoder = U2STModel._init_from_config(configs) + + if isinstance(decoder, Tuple): + st_decoder, asr_decoder, ctc = decoder + super().__init__( + vocab_size=vocab_size, + encoder=encoder, + st_decoder=st_decoder, + decoder=asr_decoder, + ctc=ctc, + **configs['model_conf']) + else: + super().__init__( + vocab_size=vocab_size, + encoder=encoder, + st_decoder=decoder, + **configs['model_conf']) + + @classmethod + def _init_from_config(cls, configs: dict): + """init sub module for model. + + Args: + configs (dict): config dict. + + Raises: + ValueError: raise when using not support encoder type. + + Returns: + int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc + """ + if configs['cmvn_file'] is not None: + mean, istd = load_cmvn(configs['cmvn_file'], + configs['cmvn_file_type']) + global_cmvn = GlobalCMVN( + paddle.to_tensor(mean, dtype=paddle.float), + paddle.to_tensor(istd, dtype=paddle.float)) + else: + global_cmvn = None + + input_dim = configs['input_dim'] + vocab_size = configs['output_dim'] + assert input_dim != 0, input_dim + assert vocab_size != 0, vocab_size + + encoder_type = configs.get('encoder', 'transformer') + logger.info(f"U2 Encoder type: {encoder_type}") + if encoder_type == 'transformer': + encoder = TransformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + elif encoder_type == 'conformer': + encoder = ConformerEncoder( + input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + else: + raise ValueError(f"not support encoder type:{encoder_type}") + + st_decoder = TransformerDecoder(vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + + asr_weight = configs['model_conf']['asr_weight'] + logger.info(f"ASR Joint Training Weight: {asr_weight}") + + if asr_weight > 0.: + decoder = TransformerDecoder(vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + ctc = CTCDecoder( + odim=vocab_size, + enc_n_units=encoder.output_size(), + blank_id=0, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True) # sum / batch_size + + return vocab_size, encoder, (st_decoder, decoder, ctc) + else: + return vocab_size, encoder, st_decoder + + @classmethod + def from_config(cls, configs: dict): + """init model. + + Args: + configs (dict): config dict. + + Raises: + ValueError: raise when using not support encoder type. + + Returns: + nn.Layer: U2STModel + """ + model = cls(configs) + return model + + @classmethod + def from_pretrained(cls, dataloader, config, checkpoint_path): + """Build a DeepSpeech2Model model from a pretrained model. + + Args: + dataloader (paddle.io.DataLoader): not used. + config (yacs.config.CfgNode): model configs + checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name + + Returns: + DeepSpeech2Model: The model built from pretrained result. + """ + config.defrost() + config.input_dim = dataloader.collate_fn.feature_size + config.output_dim = dataloader.collate_fn.vocab_size + config.freeze() + model = cls.from_config(config) + + if checkpoint_path: + infos = checkpoint.load_parameters( + model, checkpoint_path=checkpoint_path) + logger.info(f"checkpoint info: {infos}") + layer_tools.summary(model) + return model + + +class U2STInferModel(U2STModel): + def __init__(self, configs: dict): + super().__init__(configs) + + def forward(self, + feats, + feats_lengths, + decoding_chunk_size=-1, + num_decoding_left_chunks=-1, + simulate_streaming=False): + """export model function + + Args: + feats (Tensor): [B, T, D] + feats_lengths (Tensor): [B] + + Returns: + List[List[int]]: best path result + """ + return self.translate( + feats, + feats_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) diff --git a/deepspeech/utils/bleu_score.py b/deepspeech/utils/bleu_score.py new file mode 100644 index 00000000..580fbf61 --- /dev/null +++ b/deepspeech/utils/bleu_score.py @@ -0,0 +1,53 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module provides functions to calculate bleu score in different level. +e.g. wer for word-level, cer for char-level. +""" +import numpy as np +import sacrebleu + +__all__ = ['bleu', 'char_bleu'] + + +def bleu(hypothesis, reference): + """Calculate BLEU. BLEU compares reference text and + hypothesis text in word-level using scarebleu. + + + + :param reference: The reference sentences. + :type reference: list[list[str]] + :param hypothesis: The hypothesis sentence. + :type hypothesis: list[str] + :raises ValueError: If the reference length is zero. + """ + + return sacrebleu.corpus_bleu(hypothesis, reference) + +def char_bleu(hypothesis, reference): + """Calculate BLEU. BLEU compares reference text and + hypothesis text in char-level using scarebleu. + + + + :param reference: The reference sentences. + :type reference: list[list[str]] + :param hypothesis: The hypothesis sentence. + :type hypothesis: list[str] + :raises ValueError: If the reference number is zero. + """ + hypothesis =[' '.join(list(hyp.replace(' ', ''))) for hyp in hypothesis] + reference = [[' '.join(list(ref_i.replace(' ', ''))) for ref_i in ref ]for ref in reference ] + + return sacrebleu.corpus_bleu(hypothesis, reference) \ No newline at end of file