parent
0323151912
commit
ac0ae57ef2
@ -0,0 +1,666 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
from collections import namedtuple
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import kaldiio
|
||||
import numpy as np
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
|
||||
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
|
||||
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||
from deepspeech.frontend.normalizer import FeatureNormalizer
|
||||
from deepspeech.frontend.speech import SpeechSegment
|
||||
from deepspeech.frontend.utility import IGNORE_ID
|
||||
from deepspeech.io.utility import pad_sequence
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
# Public API of this module. The two Triplet* collators are defined in this
# file as well, so they are exported alongside their base classes.
__all__ = [
    "SpeechCollator",
    "TripletSpeechCollator",
    "KaldiPrePorocessedCollator",
    "TripletKaldiPrePorocessedCollator",
]

logger = Log(__name__).getlog()

# The namedtuple type must be defined at module level so that it can be
# pickled.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
|
||||
|
||||
|
||||
class SpeechCollator:
    """Collate ``(utt, audio, text)`` examples into padded numpy batches.

    Each example's audio is loaded, augmented, featurized and optionally
    normalized by :meth:`process_utterance`; :meth:`__call__` then pads the
    features and the token sequences of one batch.
    """

    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        """Return the default collator configuration.

        Args:
            config (CfgNode, optional): when given,
                ``config.merge_from_other_cfg(default)`` is called so the
                defaults are merged into it in place.

        Returns:
            CfgNode: the node holding the default values.
        """
        default = CfgNode(
            dict(
                augmentation_config="",
                random_seed=0,
                mean_std_filepath="",
                unit_type="char",
                vocab_filepath="",
                spm_model_prefix="",
                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
                feat_dim=0,  # 'mfcc', 'fbank'
                delta_delta=False,  # 'mfcc', 'fbank'
                stride_ms=10.0,  # ms
                window_ms=20.0,  # ms
                n_fft=None,  # fft points
                max_freq=None,  # None for samplerate/2
                target_sample_rate=16000,  # target sample rate
                use_dB_normalization=True,
                target_dB=-20,
                dither=1.0,  # feature dither
                keep_transcription_text=False))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a SpeechCollator object from a config.

        Args:
            config (yacs.config.CfgNode): configs object; the collator
                options are read from ``config.collator``.

        Returns:
            SpeechCollator: collator object.

        Raises:
            AssertionError: if a required key is missing from
                ``config.collator``, or if ``augmentation_config`` is neither
                a string/bytes path nor an ``io.StringIO``.
        """
        assert 'augmentation_config' in config.collator
        assert 'keep_transcription_text' in config.collator
        assert 'mean_std_filepath' in config.collator
        assert 'vocab_filepath' in config.collator
        assert 'specgram_type' in config.collator
        assert 'n_fft' in config.collator
        assert config.collator

        aug_config = config.collator.augmentation_config
        # Only a file opened here is closed here; a caller-supplied StringIO
        # stays under the caller's control.
        opened_here = False
        if isinstance(aug_config, (str, bytes)):
            if aug_config:
                aug_file = io.open(aug_config, mode='r', encoding='utf8')
                opened_here = True
            else:
                # An empty path means "no augmentation": use an empty JSON doc.
                aug_file = io.StringIO(initial_value='{}', newline='')
        else:
            aug_file = aug_config
            assert isinstance(aug_file, io.StringIO)

        try:
            # NOTE(review): random_seed is fixed to 0 here and the value in
            # config.collator.random_seed is not consulted -- confirm intent.
            return cls(
                aug_file=aug_file,
                random_seed=0,
                mean_std_filepath=config.collator.mean_std_filepath,
                unit_type=config.collator.unit_type,
                vocab_filepath=config.collator.vocab_filepath,
                spm_model_prefix=config.collator.spm_model_prefix,
                specgram_type=config.collator.specgram_type,
                feat_dim=config.collator.feat_dim,
                delta_delta=config.collator.delta_delta,
                stride_ms=config.collator.stride_ms,
                window_ms=config.collator.window_ms,
                n_fft=config.collator.n_fft,
                max_freq=config.collator.max_freq,
                target_sample_rate=config.collator.target_sample_rate,
                use_dB_normalization=config.collator.use_dB_normalization,
                target_dB=config.collator.target_dB,
                dither=config.collator.dither,
                keep_transcription_text=config.collator.keep_transcription_text)
        finally:
            # __init__ consumes the file with a single read(), so the handle
            # is no longer needed once construction has finished (or failed).
            if opened_here:
                aug_file.close()

    def __init__(
            self,
            aug_file,
            mean_std_filepath,
            vocab_filepath,
            spm_model_prefix,
            random_seed=0,
            unit_type="char",
            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
            feat_dim=0,  # 'mfcc', 'fbank'
            delta_delta=False,  # 'mfcc', 'fbank'
            stride_ms=10.0,  # ms
            window_ms=20.0,  # ms
            n_fft=None,  # fft points
            max_freq=None,  # None for samplerate/2
            target_sample_rate=16000,  # target sample rate
            use_dB_normalization=True,
            target_dB=-20,
            dither=1.0,
            keep_transcription_text=True):
        """SpeechCollator Collator

        Args:
            aug_file (file-like): readable text stream holding the augmentation
                config; it is read once, in full, during construction.
            mean_std_filepath (str): mean and std file path; when empty no
                feature normalizer is created.
            vocab_filepath (str): vocab file path.
            spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
            random_seed (int, optional): for random generator. Defaults to 0.
            unit_type(str): token unit type, e.g. char, word, spm
            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
            feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to 0.
            delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
            window_ms (float, optional): window size in ms. Defaults to 20.0.
            n_fft (int, optional): fft points for rfft. Defaults to None.
            max_freq (int, optional): max cut freq. Defaults to None.
            target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
            target_dB (int, optional): target dB. Defaults to -20.
            dither (float, optional): feature dither. Defaults to 1.0.
            keep_transcription_text (bool, optional): if False, text is
                converted to token ids; otherwise the raw string is kept.
                Defaults to True.
        """
        self._keep_transcription_text = keep_transcription_text

        # Cache of opened tar archives, filled lazily by _subfile_from_tar.
        self._local_data = TarLocalData(tar2info={}, tar2object={})
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=aug_file.read(), random_seed=random_seed)

        self._normalizer = FeatureNormalizer(
            mean_std_filepath) if mean_std_filepath else None

        self._stride_ms = stride_ms
        self._target_sample_rate = target_sample_rate

        self._speech_featurizer = SpeechFeaturizer(
            unit_type=unit_type,
            vocab_filepath=vocab_filepath,
            spm_model_prefix=spm_model_prefix,
            specgram_type=specgram_type,
            feat_dim=feat_dim,
            delta_delta=delta_delta,
            stride_ms=stride_ms,
            window_ms=window_ms,
            n_fft=n_fft,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB,
            dither=dither)

    def _parse_tar(self, file):
        """Open a tar archive and index its members by name.

        Args:
            file (str): path of the tar archive.

        Returns:
            tuple: ``(tarfile.TarFile, dict)`` where the dict maps each member
            name to its ``TarInfo``. The archive is returned open; the caller
            owns it.
        """
        # Imported locally because this module has no top-level tarfile import.
        import tarfile

        archive = tarfile.open(file)
        members = {member.name: member for member in archive.getmembers()}
        return archive, members

    def _subfile_from_tar(self, file):
        """Get subfile object from tar.

        Args:
            file (str): reference of the form ``<prefix>:<tarpath>#<member>``.

        Returns:
            A file object for the requested member. The opened archive and its
            member index are cached in ``self._local_data`` for later requests.
        """
        tarpath, filename = file.split(':', 1)[1].split('#', 1)
        # The namedtuple itself is immutable and has no __dict__, but the two
        # dicts it holds are created in __init__ and can be filled in place.
        cache = self._local_data
        if tarpath not in cache.tar2info:
            archive, infoes = self._parse_tar(tarpath)
            cache.tar2info[tarpath] = infoes
            cache.tar2object[tarpath] = archive
        return cache.tar2object[tarpath].extractfile(
            cache.tar2info[tarpath][filename])

    def process_utterance(self, audio_file, translation):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file. A string
                           starting with ``tar:`` is resolved through the tar
                           cache.
        :type audio_file: str | file
        :param translation: translation text.
        :type translation: str
        :return: Tuple of audio feature tensor and data of translation part,
                 where translation part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), translation)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, translation)

        # audio augment
        self._augmentation_pipeline.transform_audio(speech_segment)

        specgram, translation_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        if self._normalizer:
            specgram = self._normalizer.apply(specgram)

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        # Swap the two axes; __call__ treats axis 0 of the result as the length.
        specgram = specgram.transpose([1, 0])
        return specgram, translation_part

    def __call__(self, batch):
        """batch examples

        Args:
            batch ([List]): each item is (utt, audio, text)
                utt: utterance identifier, passed through unchanged
                audio: filepath or file object given to ``process_utterance``
                text (str): text given to ``process_utterance``

        Returns:
            tuple(utts, audio, audio_lens, text, text_lens): batched data.
                utts: list of utterance identifiers
                audio : padded features, float32
                audio_lens: (B), int64
                text : padded tokens, int64, padded with IGNORE_ID
                text_lens: (B), int64
        """
        audios = []
        audio_lens = []
        texts = []
        text_lens = []
        utts = []
        for utt, audio, text in batch:
            audio, text = self.process_utterance(audio, text)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids
            # else text is string, convert to unicode ord
            if self._keep_transcription_text:
                assert isinstance(text, str), (type(text), text)
                tokens = [ord(t) for t in text]
            else:
                tokens = text  # token ids
            if not isinstance(tokens, np.ndarray):
                tokens = np.array(tokens, dtype=np.int64)
            texts.append(tokens)
            text_lens.append(tokens.shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        padded_texts = pad_sequence(
            texts, padding_value=IGNORE_ID).astype(np.int64)
        text_lens = np.array(text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, padded_texts, text_lens

    @property
    def manifest(self):
        """Return ``self._manifest``.

        NOTE(review): nothing in this class assigns ``_manifest``; unless it
        is set elsewhere this raises AttributeError -- confirm.
        """
        return self._manifest

    @property
    def vocab_size(self):
        """Vocabulary size reported by the speech featurizer."""
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Vocabulary list reported by the speech featurizer."""
        return self._speech_featurizer.vocab_list

    @property
    def vocab_dict(self):
        """Vocabulary dict reported by the speech featurizer."""
        return self._speech_featurizer.vocab_dict

    @property
    def text_feature(self):
        """The ``text_feature`` attribute of the speech featurizer."""
        return self._speech_featurizer.text_feature

    @property
    def feature_size(self):
        """Feature size reported by the speech featurizer."""
        return self._speech_featurizer.feature_size

    @property
    def stride_ms(self):
        """Stride in ms reported by the speech featurizer."""
        return self._speech_featurizer.stride_ms
|
||||
|
||||
|
||||
class TripletSpeechCollator(SpeechCollator):
    """Collator for ``(utt, audio, translation, transcription)`` examples."""

    def process_utterance(self, audio_file, translation, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file. A string
                           starting with ``tar:`` is resolved through the tar
                           cache.
        :type audio_file: str | file
        :param translation: translation text.
        :type translation: str
        :param transcript: transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor, translation part and
                 transcription part, where both text parts are token ids, or
                 the raw text when ``keep_transcription_text`` is set.
        :rtype: tuple of (2darray, list, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), translation)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, translation)

        # audio augment
        self._augmentation_pipeline.transform_audio(speech_segment)

        specgram, translation_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        if self._keep_transcription_text:
            # Keep the raw string: __call__ asserts that both texts are str
            # in this mode and converts them itself.
            transcript_part = transcript
        else:
            transcript_part = self._speech_featurizer._text_featurizer.featurize(
                transcript)
        if self._normalizer:
            specgram = self._normalizer.apply(specgram)

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        # Swap the two axes; __call__ treats axis 0 of the result as the length.
        specgram = specgram.transpose([1, 0])
        return specgram, translation_part, transcript_part

    def __call__(self, batch):
        """batch examples

        Args:
            batch ([List]): each item is (utt, audio, translation, transcription)
                utt: utterance identifier, passed through unchanged
                audio: filepath or file object given to ``process_utterance``
                translation (str): translation text
                transcription (str): transcription text

        Returns:
            tuple(utts, audio, audio_lens, texts, text_lens): batched data.
                utts: list of utterance identifiers
                audio : padded features, float32
                audio_lens: (B), int64
                texts : (padded_translation, padded_transcription), int64,
                    padded with IGNORE_ID
                text_lens: (translation_lens, transcription_lens), int64
        """
        audios = []
        audio_lens = []
        translation_text = []
        translation_text_lens = []
        transcription_text = []
        transcription_text_lens = []

        utts = []
        for utt, audio, translation, transcription in batch:
            audio, translation, transcription = self.process_utterance(
                audio, translation, transcription)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids
            # else text is string, convert to unicode ord
            tokens = [[], []]
            for idx, text in enumerate([translation, transcription]):
                if self._keep_transcription_text:
                    assert isinstance(text, str), (type(text), text)
                    tokens[idx] = [ord(t) for t in text]
                else:
                    tokens[idx] = text  # token ids
                if not isinstance(tokens[idx], np.ndarray):
                    tokens[idx] = np.array(tokens[idx], dtype=np.int64)
            translation_text.append(tokens[0])
            translation_text_lens.append(tokens[0].shape[0])
            transcription_text.append(tokens[1])
            transcription_text_lens.append(tokens[1].shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  #[B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        padded_translation = pad_sequence(
            translation_text, padding_value=IGNORE_ID).astype(np.int64)
        translation_lens = np.array(translation_text_lens).astype(np.int64)
        padded_transcription = pad_sequence(
            transcription_text, padding_value=IGNORE_ID).astype(np.int64)
        transcription_lens = np.array(transcription_text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, (
            padded_translation, padded_transcription), (translation_lens,
                                                        transcription_lens)
|
||||
|
||||
|
||||
class KaldiPrePorocessedCollator(SpeechCollator):
    """Collator for features that were already extracted and stored as Kaldi
    matrices; only feature augmentation and text tokenization happen here.
    """

    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        """Return the default collator configuration.

        Args:
            config (CfgNode, optional): when given,
                ``config.merge_from_other_cfg(default)`` is called so the
                defaults are merged into it in place.

        Returns:
            CfgNode: the node holding the default values.
        """
        default = CfgNode(
            dict(
                augmentation_config="",
                random_seed=0,
                unit_type="char",
                vocab_filepath="",
                spm_model_prefix="",
                feat_dim=0,
                stride_ms=10.0,
                keep_transcription_text=False))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a KaldiPrePorocessedCollator object from a config.

        Args:
            config (yacs.config.CfgNode): configs object; the collator
                options are read from ``config.collator``.

        Returns:
            KaldiPrePorocessedCollator: collator object.

        Raises:
            AssertionError: if a required key is missing from
                ``config.collator``, or if ``augmentation_config`` is neither
                a string/bytes path nor an ``io.StringIO``.
        """
        assert 'augmentation_config' in config.collator
        assert 'keep_transcription_text' in config.collator
        assert 'vocab_filepath' in config.collator
        assert config.collator

        aug_config = config.collator.augmentation_config
        # Only a file opened here is closed here; a caller-supplied StringIO
        # stays under the caller's control.
        opened_here = False
        if isinstance(aug_config, (str, bytes)):
            if aug_config:
                aug_file = io.open(aug_config, mode='r', encoding='utf8')
                opened_here = True
            else:
                # An empty path means "no augmentation": use an empty JSON doc.
                aug_file = io.StringIO(initial_value='{}', newline='')
        else:
            aug_file = aug_config
            assert isinstance(aug_file, io.StringIO)

        try:
            # NOTE(review): random_seed is fixed to 0 here and the value in
            # config.collator.random_seed is not consulted -- confirm intent.
            return cls(
                aug_file=aug_file,
                random_seed=0,
                unit_type=config.collator.unit_type,
                vocab_filepath=config.collator.vocab_filepath,
                spm_model_prefix=config.collator.spm_model_prefix,
                feat_dim=config.collator.feat_dim,
                stride_ms=config.collator.stride_ms,
                keep_transcription_text=config.collator.keep_transcription_text)
        finally:
            # __init__ consumes the file with a single read(), so the handle
            # is no longer needed once construction has finished (or failed).
            if opened_here:
                aug_file.close()

    def __init__(self,
                 aug_file,
                 vocab_filepath,
                 spm_model_prefix,
                 random_seed=0,
                 unit_type="char",
                 feat_dim=0,
                 stride_ms=10.0,
                 keep_transcription_text=True):
        """KaldiPrePorocessedCollator Collator

        Args:
            aug_file (file-like): readable text stream holding the augmentation
                config; it is read once, in full, during construction.
            vocab_filepath (str): vocab file path.
            spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
            random_seed (int, optional): for random generator. Defaults to 0.
            unit_type(str): token unit type, e.g. char, word, spm
            feat_dim (int, optional): expected feature dimension of the loaded
                matrices. Defaults to 0.
            stride_ms (float, optional): stride size in ms; only stored and
                exposed through ``stride_ms``. Defaults to 10.0.
            keep_transcription_text (bool, optional): if False, text is
                converted to token ids; otherwise the raw string is kept.
                Defaults to True.
        """
        # The parent constructor is deliberately not called: no speech
        # featurizer or normalizer is built for pre-extracted features.
        self._keep_transcription_text = keep_transcription_text
        self._feat_dim = feat_dim
        self._stride_ms = stride_ms

        self._local_data = TarLocalData(tar2info={}, tar2object={})
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=aug_file.read(), random_seed=random_seed)

        self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
                                               spm_model_prefix)

    def process_utterance(self, audio_file, translation):
        """Load and augment a pre-extracted feature matrix, tokenize the text.

        :param audio_file: Filepath or file object of kaldi processed feature.
        :type audio_file: str | file
        :param translation: Translation text.
        :type translation: str
        :return: Tuple of audio feature tensor and data of translation part,
                 where translation part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        specgram = kaldiio.load_mat(audio_file)
        # After this swap axis 0 is checked against the feature dimension.
        specgram = specgram.transpose([1, 0])
        assert specgram.shape[
            0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
                self._feat_dim, specgram.shape[0])

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)

        # Swap back so that axis 0 is the one __call__ uses as the length.
        specgram = specgram.transpose([1, 0])
        if self._keep_transcription_text:
            return specgram, translation
        else:
            text_ids = self._text_featurizer.featurize(translation)
            return specgram, text_ids

    @property
    def manifest(self):
        """Return ``self._manifest``.

        NOTE(review): nothing in this class assigns ``_manifest``; unless it
        is set elsewhere this raises AttributeError -- confirm.
        """
        return self._manifest

    @property
    def vocab_size(self):
        """Vocabulary size reported by the text featurizer."""
        return self._text_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Vocabulary list reported by the text featurizer."""
        return self._text_featurizer.vocab_list

    @property
    def vocab_dict(self):
        """Vocabulary dict reported by the text featurizer."""
        return self._text_featurizer.vocab_dict

    @property
    def text_feature(self):
        """The text featurizer itself."""
        return self._text_featurizer

    @property
    def feature_size(self):
        """The configured feature dimension."""
        return self._feat_dim

    @property
    def stride_ms(self):
        """The configured stride in ms."""
        return self._stride_ms
|
||||
|
||||
|
||||
class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
    """Kaldi-feature collator for ``(utt, audio, translation, transcription)``
    examples.
    """

    def process_utterance(self, audio_file, translation, transcript):
        """Load and augment a pre-extracted feature matrix, tokenize both texts.

        :param audio_file: Filepath or file object of kaldi processed feature.
        :type audio_file: str | file
        :param translation: Translation text.
        :type translation: str
        :param transcript: Transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of translation and
                 transcription parts, where both parts could be token ids or
                 text.
        :rtype: tuple of (2darray, list, list)
        """
        feats = kaldiio.load_mat(audio_file).transpose([1, 0])
        got_dim = feats.shape[0]
        assert got_dim == self._feat_dim, \
            'expect feat dim {}, but got {}'.format(self._feat_dim,
                                                    feats.shape[0])

        # specgram augment, then swap the axes back
        feats = self._augmentation_pipeline.transform_feature(feats)
        feats = feats.transpose([1, 0])

        if not self._keep_transcription_text:
            featurize = self._text_featurizer.featurize
            translation_ids = featurize(translation)
            transcript_ids = featurize(transcript)
            return feats, translation_ids, transcript_ids
        return feats, translation, transcript

    def __call__(self, batch):
        """batch examples

        Args:
            batch ([List]): each item is (utt, audio, translation, transcription)
                utt: utterance identifier, passed through unchanged
                audio: filepath or file object given to ``process_utterance``
                translation (str): translation text
                transcription (str): transcription text

        Returns:
            tuple(utts, audio, audio_lens, texts, text_lens): batched data.
                utts: list of utterance identifiers
                audio : padded features, float32
                audio_lens: (B), int64
                texts : (padded_translation, padded_transcription), int64,
                    padded with IGNORE_ID
                text_lens: (translation_lens, transcription_lens), int64
        """
        utts = []
        feats, feat_lens = [], []
        st_tokens, st_lens = [], []
        asr_tokens, asr_lens = [], []

        for utt, audio, translation, transcription in batch:
            feat, translation, transcription = self.process_utterance(
                audio, translation, transcription)
            utts.append(utt)
            feats.append(feat)  # [T, D]
            feat_lens.append(feat.shape[0])

            # In training the texts are already token ids; otherwise they are
            # strings and are mapped to unicode code points.
            converted = []
            for text in (translation, transcription):
                if self._keep_transcription_text:
                    assert isinstance(text, str), (type(text), text)
                    ids = [ord(ch) for ch in text]
                else:
                    ids = text
                if not isinstance(ids, np.ndarray):
                    ids = np.array(ids, dtype=np.int64)
                converted.append(ids)

            st_ids, asr_ids = converted
            st_tokens.append(st_ids)
            st_lens.append(st_ids.shape[0])
            asr_tokens.append(asr_ids)
            asr_lens.append(asr_ids.shape[0])

        padded_audios = pad_sequence(feats, padding_value=0.0)
        padded_audios = padded_audios.astype(np.float32)  # [B, T, D]
        audio_lens = np.array(feat_lens).astype(np.int64)

        padded_translation = pad_sequence(st_tokens, padding_value=IGNORE_ID)
        padded_translation = padded_translation.astype(np.int64)
        translation_lens = np.array(st_lens).astype(np.int64)

        padded_transcription = pad_sequence(asr_tokens, padding_value=IGNORE_ID)
        padded_transcription = padded_transcription.astype(np.int64)
        transcription_lens = np.array(asr_lens).astype(np.int64)

        texts = (padded_translation, padded_transcription)
        lens = (translation_lens, transcription_lens)
        return utts, padded_audios, audio_lens, texts, lens
|
@ -0,0 +1,734 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""U2 ASR Model
|
||||
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
|
||||
(https://arxiv.org/pdf/2012.05481.pdf)
|
||||
"""
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
from paddle import jit
|
||||
from paddle import nn
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from deepspeech.frontend.utility import IGNORE_ID
|
||||
from deepspeech.frontend.utility import load_cmvn
|
||||
from deepspeech.modules.cmvn import GlobalCMVN
|
||||
from deepspeech.modules.ctc import CTCDecoder
|
||||
from deepspeech.modules.decoder import TransformerDecoder
|
||||
from deepspeech.modules.encoder import ConformerEncoder
|
||||
from deepspeech.modules.encoder import TransformerEncoder
|
||||
from deepspeech.modules.loss import LabelSmoothingLoss
|
||||
from deepspeech.modules.mask import make_pad_mask
|
||||
from deepspeech.modules.mask import mask_finished_preds
|
||||
from deepspeech.modules.mask import mask_finished_scores
|
||||
from deepspeech.modules.mask import subsequent_mask
|
||||
from deepspeech.utils import checkpoint
|
||||
from deepspeech.utils import layer_tools
|
||||
from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
|
||||
from deepspeech.utils.log import Log
|
||||
from deepspeech.utils.tensor_utils import add_sos_eos
|
||||
from deepspeech.utils.tensor_utils import pad_sequence
|
||||
from deepspeech.utils.tensor_utils import th_accuracy
|
||||
from deepspeech.utils.utility import log_add
|
||||
|
||||
__all__ = ["U2STModel", "U2STInferModel"]
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
class U2STBaseModel(nn.Module):
|
||||
"""CTC-Attention hybrid Encoder-Decoder model"""
|
||||
|
||||
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
    """Return the default model configuration.

    Args:
        config (CfgNode, optional): when given,
            ``config.merge_from_other_cfg(defaults)`` is called so the
            defaults are merged into it in place.

    Returns:
        CfgNode: the node holding the default values.
    """
    defaults = CfgNode()

    # Top-level entries; allow add new item when merge_with_file.
    defaults.cmvn_file = ""
    defaults.cmvn_file_type = "json"
    defaults.input_dim = 0
    defaults.output_dim = 0

    # Encoder. The options use_cnn_module, cnn_module_kernel,
    # activation_type, pos_enc_layer_type and selfattention_layer_type are
    # intentionally not part of these defaults.
    encoder_defaults = {
        "output_size": 256,  # dimension of attention
        "attention_heads": 4,
        "linear_units": 2048,  # the number of units of position-wise feed forward
        "num_blocks": 12,  # the number of encoder blocks
        "dropout_rate": 0.1,
        "positional_dropout_rate": 0.1,
        "attention_dropout_rate": 0.0,
        "input_layer": 'conv2d',  # encoder input type: conv2d, conv2d6 or conv2d8
        "normalize_before": True,
    }
    defaults.encoder = 'transformer'
    defaults.encoder_conf = CfgNode(encoder_defaults)

    # Decoder.
    decoder_defaults = {
        "attention_heads": 4,
        "linear_units": 2048,
        "num_blocks": 6,
        "dropout_rate": 0.1,
        "positional_dropout_rate": 0.1,
        "self_attention_dropout_rate": 0.0,
        "src_attention_dropout_rate": 0.0,
    }
    defaults.decoder = 'transformer'
    defaults.decoder_conf = CfgNode(decoder_defaults)

    # Hybrid CTC/attention loss weighting.
    loss_defaults = {
        "asr_weight": 0.0,
        "ctc_weight": 0.0,
        "lsm_weight": 0.1,  # label smoothing option
        "length_normalized_loss": False,
    }
    defaults.model_conf = CfgNode(loss_defaults)

    if config is None:
        return defaults
    config.merge_from_other_cfg(defaults)
    return defaults
|
||||
|
||||
def __init__(self,
             vocab_size: int,
             encoder: TransformerEncoder,
             st_decoder: TransformerDecoder,
             decoder: TransformerDecoder=None,
             ctc: CTCDecoder=None,
             ctc_weight: float=0.0,
             asr_weight: float=0.0,
             ignore_id: int=IGNORE_ID,
             lsm_weight: float=0.0,
             length_normalized_loss: bool=False):
    """Store the sub-modules and loss weights of the model.

    Args:
        vocab_size (int): vocabulary size; the last id is used as both the
            sos and the eos symbol.
        encoder: speech encoder.
        st_decoder: decoder stored as ``self.st_decoder``.
        decoder: decoder stored as ``self.decoder``. Defaults to None.
        ctc: CTC module stored as ``self.ctc``. Defaults to None.
        ctc_weight (float): CTC share of the ASR loss, must lie in [0, 1].
        asr_weight (float): ASR share of the total loss.
        ignore_id (int): padding id ignored by the attention loss.
        lsm_weight (float): label smoothing weight.
        length_normalized_loss (bool): passed to the label smoothing loss
            as ``normalize_length``.

    Raises:
        AssertionError: if ``ctc_weight`` is outside [0, 1].
    """
    assert 0.0 <= ctc_weight <= 1.0, ctc_weight

    super().__init__()

    # Plain (non-module) attributes.
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    self.ctc_weight = ctc_weight
    self.asr_weight = asr_weight
    # note that eos is the same as sos (equivalent ID)
    last_token_id = vocab_size - 1
    self.sos = last_token_id
    self.eos = last_token_id

    # Sub-modules, assigned in the same order as before.
    self.encoder = encoder
    self.st_decoder = st_decoder
    self.decoder = decoder
    self.ctc = ctc
    self.criterion_att = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss)
|
||||
|
||||
def forward(
        self,
        speech: paddle.Tensor,
        speech_lengths: paddle.Tensor,
        text: paddle.Tensor,
        text_lengths: paddle.Tensor,
        asr_text: paddle.Tensor=None,
        asr_text_lengths: paddle.Tensor=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, Optional[paddle.Tensor], Optional[
        paddle.Tensor]]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
        asr_text: ASR target used by the attention and CTC branches;
            required when ``self.asr_weight > 0``.
        asr_text_lengths: lengths of ``asr_text``; required when
            ``self.asr_weight > 0``.

    Returns:
        total_loss, st_loss, asr_attention_loss, asr_ctc_loss. The two ASR
        losses are None when their branch is not evaluated.

    Raises:
        AssertionError: on inconsistent batch sizes, or when the ASR branch
            is enabled but no ASR targets are given.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
            text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                     text.shape, text_lengths.shape)
    # 1. Encoder
    encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
    #TODO(Hui Zhang): sum not support bool type
    #encoder_out_lens = encoder_mask.squeeze(1).sum(1)  #[B, 1, T] -> [B]
    encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
        1)  #[B, 1, T] -> [B]

    # 2a. ST-decoder branch (the accuracy value is not used here)
    loss_st, _ = self._calc_st_loss(encoder_out, encoder_mask, text,
                                    text_lengths)

    loss_asr_att = None
    loss_asr_ctc = None
    if self.asr_weight > 0.:
        # Fail early with a clear message instead of deep inside a branch.
        assert asr_text is not None and asr_text_lengths is not None, (
            "asr_text and asr_text_lengths are required when asr_weight > 0")

        # 2b. ASR Attention-decoder branch
        if self.ctc_weight != 1.0:
            loss_asr_att, _ = self._calc_att_loss(
                encoder_out, encoder_mask, asr_text, asr_text_lengths)

        # 2c. CTC branch
        if self.ctc_weight != 0.0:
            loss_asr_ctc = self.ctc(encoder_out, encoder_out_lens, asr_text,
                                    asr_text_lengths)

        # At least one of the two ASR losses is set at this point.
        if loss_asr_ctc is None:
            loss_asr = loss_asr_att
        elif loss_asr_att is None:
            loss_asr = loss_asr_ctc
        else:
            loss_asr = self.ctc_weight * loss_asr_ctc + (
                1 - self.ctc_weight) * loss_asr_att
        loss = self.asr_weight * loss_asr + (1 - self.asr_weight) * loss_st
    else:
        loss = loss_st
    return loss, loss_st, loss_asr_att, loss_asr_ctc
|
||||
|
||||
def _calc_st_loss(
        self,
        encoder_out: paddle.Tensor,
        encoder_mask: paddle.Tensor,
        ys_pad: paddle.Tensor,
        ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
    """Compute the loss and accuracy of the ST decoder (``self.st_decoder``).

    Args:
        encoder_out (paddle.Tensor): [B, Tmax, D]
        encoder_mask (paddle.Tensor): [B, 1, Tmax]
        ys_pad (paddle.Tensor): [B, Umax]
        ys_pad_lens (paddle.Tensor): [B]

    Returns:
        Tuple[paddle.Tensor, float]: loss from ``self.criterion_att`` and
            the accuracy reported by ``th_accuracy``.
    """
    # Build the decoder input / target pair; the input lengths are the label
    # lengths plus one.
    dec_in, dec_target = add_sos_eos(ys_pad, self.sos, self.eos,
                                     self.ignore_id)
    dec_in_lens = ys_pad_lens + 1

    # Forward the ST decoder; only its first output is used here.
    dec_out, _ = self.st_decoder(encoder_out, encoder_mask, dec_in,
                                 dec_in_lens)

    st_loss = self.criterion_att(dec_out, dec_target)

    # Accuracy is computed on the decoder output flattened to (-1, vocab_size).
    flat_out = dec_out.view(-1, self.vocab_size)
    st_acc = th_accuracy(flat_out, dec_target, ignore_label=self.ignore_id)
    return st_loss, st_acc
|
||||
|
||||
def _calc_att_loss(
        self,
        encoder_out: paddle.Tensor,
        encoder_mask: paddle.Tensor,
        ys_pad: paddle.Tensor,
        ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
    """Compute the loss and accuracy of the ASR decoder (``self.decoder``).

    Args:
        encoder_out (paddle.Tensor): [B, Tmax, D]
        encoder_mask (paddle.Tensor): [B, 1, Tmax]
        ys_pad (paddle.Tensor): [B, Umax]
        ys_pad_lens (paddle.Tensor): [B]

    Returns:
        Tuple[paddle.Tensor, float]: loss from ``self.criterion_att`` and
            the accuracy reported by ``th_accuracy``.
    """
    # Build the decoder input / target pair; the input lengths are the label
    # lengths plus one.
    dec_in, dec_target = add_sos_eos(ys_pad, self.sos, self.eos,
                                     self.ignore_id)
    dec_in_lens = ys_pad_lens + 1

    # Forward the ASR attention decoder; only its first output is used here.
    dec_out, _ = self.decoder(encoder_out, encoder_mask, dec_in, dec_in_lens)

    att_loss = self.criterion_att(dec_out, dec_target)

    # Accuracy is computed on the decoder output flattened to (-1, vocab_size).
    flat_out = dec_out.view(-1, self.vocab_size)
    att_acc = th_accuracy(flat_out, dec_target, ignore_label=self.ignore_id)
    return att_loss, att_acc
|
||||
|
||||
def _forward_encoder(
        self,
        speech: paddle.Tensor,
        speech_lengths: paddle.Tensor,
        decoding_chunk_size: int=-1,
        num_decoding_left_chunks: int=-1,
        simulate_streaming: bool=False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Run the encoder, either in one pass or chunk by chunk.

    Args:
        speech (paddle.Tensor): [B, Tmax, D]
        speech_lengths (paddle.Tensor): [B]
        decoding_chunk_size (int, optional): chunk size. Defaults to -1.
        num_decoding_left_chunks (int, optional): number of left chunks.
            Defaults to -1.
        simulate_streaming (bool, optional): when True and
            ``decoding_chunk_size > 0`` the encoder's
            ``forward_chunk_by_chunk`` is used. Defaults to False.

    Returns:
        Tuple[paddle.Tensor, paddle.Tensor]:
            encoder hiddens (B, Tmax, D),
            encoder hiddens mask (B, 1, Tmax).
    """
    # Let's assume B = batch_size
    chunk_by_chunk = simulate_streaming and decoding_chunk_size > 0
    if not chunk_by_chunk:
        # Regular forward pass; the lengths are passed through.
        hiddens, hiddens_mask = self.encoder(
            speech,
            speech_lengths,
            decoding_chunk_size=decoding_chunk_size,
            num_decoding_left_chunks=num_decoding_left_chunks
        )  # (B, maxlen, encoder_dim)
        return hiddens, hiddens_mask

    # Simulated streaming: note that speech_lengths is not passed here.
    hiddens, hiddens_mask = self.encoder.forward_chunk_by_chunk(
        speech,
        decoding_chunk_size=decoding_chunk_size,
        num_decoding_left_chunks=num_decoding_left_chunks
    )  # (B, maxlen, encoder_dim)
    return hiddens, hiddens_mask
|
||||
|
||||
def translate(
        self,
        speech: paddle.Tensor,
        speech_lengths: paddle.Tensor,
        beam_size: int=10,
        decoding_chunk_size: int=-1,
        num_decoding_left_chunks: int=-1,
        simulate_streaming: bool=False, ) -> paddle.Tensor:
    """Apply beam search with the ST decoder (``self.st_decoder``).

    The encoder output is tiled ``beam_size`` times per utterance, then
    ``self.st_decoder.forward_one_step`` is called once per output position
    until every beam has produced eos or ``maxlen`` steps have run.

    Args:
        speech (paddle.Tensor): (batch, max_len, feat_dim)
        speech_lengths (paddle.Tensor): (batch, )
        beam_size (int): beam size for beam search
        decoding_chunk_size (int): decoding chunk for dynamic chunk
            trained model.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
            0: used for training, it's prohibited here
        num_decoding_left_chunks (int): forwarded unchanged to
            ``self._forward_encoder``.
        simulate_streaming (bool): whether do encoder forward in a
            streaming fashion
    Returns:
        paddle.Tensor: decoding result, (batch, max_result_len)
    """
    assert speech.shape[0] == speech_lengths.shape[0]
    assert decoding_chunk_size != 0
    device = speech.place
    batch_size = speech.shape[0]

    # Let's assume B = batch_size and N = beam_size
    # 1. Encoder
    encoder_out, encoder_mask = self._forward_encoder(
        speech, speech_lengths, decoding_chunk_size,
        num_decoding_left_chunks,
        simulate_streaming)  # (B, maxlen, encoder_dim)
    maxlen = encoder_out.size(1)
    encoder_dim = encoder_out.size(2)
    running_size = batch_size * beam_size
    # Tile encoder output and mask so each of the N beams has its own copy.
    encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
        running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
    encoder_mask = encoder_mask.unsqueeze(1).repeat(
        1, beam_size, 1, 1).view(running_size, 1,
                                 maxlen)  # (B*N, 1, max_len)

    # Every hypothesis starts with a single sos token.
    hyps = paddle.ones(
        [running_size, 1], dtype=paddle.long).fill_(self.sos)  # (B*N, 1)
    # log scale score: beam 0 of each utterance starts at 0.0, the other
    # beams at -inf.
    scores = paddle.to_tensor(
        [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
    scores = scores.to(device).repeat(batch_size).unsqueeze(1).to(
        device)  # (B*N, 1)
    end_flag = paddle.zeros_like(scores, dtype=paddle.bool)  # (B*N, 1)
    cache: Optional[List[paddle.Tensor]] = None
    # 2. Decoder forward step by step
    for i in range(1, maxlen + 1):
        # Stop if all batch and all beam produce eos
        # TODO(Hui Zhang): if end_flag.sum() == running_size:
        # (the bool flag is cast to int64 before summing)
        if end_flag.cast(paddle.int64).sum() == running_size:
            break

        # 2.1 Forward decoder step
        hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
            running_size, 1, 1).to(device)  # (B*N, i, i)
        # logp: (B*N, vocab)
        logp, cache = self.st_decoder.forward_one_step(
            encoder_out, encoder_mask, hyps, hyps_mask, cache)

        # 2.2 First beam prune: select topk best prob at current time
        top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
        # NOTE(review): these helpers presumably freeze finished beams
        # (scores masked, predictions forced to eos) - confirm against
        # their definitions.
        top_k_logp = mask_finished_scores(top_k_logp, end_flag)
        top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)

        # 2.3 Second beam prune: select topk score with history
        scores = scores + top_k_logp  # (B*N, N), broadcast add
        scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
        scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
        scores = scores.view(-1, 1)  # (B*N, 1)

        # 2.4. Compute base index in top_k_index,
        #   regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
        #   then find offset_k_index in top_k_index
        base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
            1, beam_size)  # (B, N)
        base_k_index = base_k_index * beam_size * beam_size
        best_k_index = base_k_index.view(-1) + offset_k_index.view(
            -1)  # (B*N)

        # 2.5 Update best hyps
        best_k_pred = paddle.index_select(
            top_k_index.view(-1), index=best_k_index, axis=0)  # (B*N)
        # Integer division by N maps a flat (B*N*N) candidate index back to
        # the row of `hyps` it was expanded from.
        best_hyps_index = best_k_index // beam_size
        last_best_k_hyps = paddle.index_select(
            hyps, index=best_hyps_index, axis=0)  # (B*N, i)
        hyps = paddle.cat(
            (last_best_k_hyps, best_k_pred.view(-1, 1)),
            dim=1)  # (B*N, i+1)

        # 2.6 Update end flag
        end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)

    # 3. Select best of best
    scores = scores.view(batch_size, beam_size)
    # TODO: length normalization
    best_index = paddle.argmax(scores, axis=-1).long()  # (B)
    best_hyps_index = best_index + paddle.arange(
        batch_size, dtype=paddle.long) * beam_size
    best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
    # Drop the first column, which holds the initial sos token.
    best_hyps = best_hyps[:, 1:]
    return best_hyps
|
||||
|
||||
@jit.export
def subsampling_rate(self) -> int:
    """Return ``self.encoder.embed.subsampling_rate``.

    Export interface for c++ call.
    """
    rate = self.encoder.embed.subsampling_rate
    return rate
|
||||
|
||||
@jit.export
def right_context(self) -> int:
    """Return ``self.encoder.embed.right_context``.

    Export interface for c++ call.
    """
    context = self.encoder.embed.right_context
    return context
|
||||
|
||||
@jit.export
def sos_symbol(self) -> int:
    """Return the sos symbol id stored on the model (``self.sos``).

    Export interface for c++ call.
    """
    sos_id = self.sos
    return sos_id
|
||||
|
||||
@jit.export
def eos_symbol(self) -> int:
    """Return the eos symbol id stored on the model (``self.eos``).

    Export interface for c++ call.
    """
    eos_id = self.eos
    return eos_id
|
||||
|
||||
@jit.export
def forward_encoder_chunk(
        self,
        xs: paddle.Tensor,
        offset: int,
        required_cache_size: int,
        subsampling_cache: Optional[paddle.Tensor]=None,
        elayers_output_cache: Optional[List[paddle.Tensor]]=None,
        conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
        paddle.Tensor]]:
    """Forward one input chunk through ``self.encoder.forward_chunk``.

    Export interface for c++ call: give input chunk xs, and return output
    from time 0 to current chunk. All arguments are passed through
    unchanged, in the same order.

    Args:
        xs (paddle.Tensor): chunk input
        offset (int): forwarded to the encoder as-is.
        required_cache_size (int): forwarded to the encoder as-is.
        subsampling_cache (Optional[paddle.Tensor]): subsampling cache
        elayers_output_cache (Optional[List[paddle.Tensor]]):
            transformer/conformer encoder layers output cache
        conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
            cnn cache

    Returns:
        paddle.Tensor: output, it ranges from time 0 to current chunk.
        paddle.Tensor: subsampling cache
        List[paddle.Tensor]: attention cache
        List[paddle.Tensor]: conformer cnn cache
    """
    chunk_result = self.encoder.forward_chunk(
        xs, offset, required_cache_size, subsampling_cache,
        elayers_output_cache, conformer_cnn_cache)
    return chunk_result
|
||||
|
||||
@jit.export
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
    """Return ``self.ctc.log_softmax(xs)``.

    Export interface for c++ call: apply linear transform and log softmax
    before ctc.

    Args:
        xs (paddle.Tensor): encoder output

    Returns:
        paddle.Tensor: activation before ctc
    """
    activation = self.ctc.log_softmax(xs)
    return activation
|
||||
|
||||
@jit.export
def forward_attention_decoder(
        self,
        hyps: paddle.Tensor,
        hyps_lens: paddle.Tensor,
        encoder_out: paddle.Tensor, ) -> paddle.Tensor:
    """Score several hypotheses against one encoder output.

    Export interface for c++ call: forward ``self.decoder`` with multiple
    hypothesis from ctc prefix beam search and one encoder output.

    Args:
        hyps (paddle.Tensor): hyps from ctc prefix beam search, already
            pad sos at the beginning, (B, T)
        hyps_lens (paddle.Tensor): length of each hyp in hyps, (B)
        encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)

    Returns:
        paddle.Tensor: log-softmax of the decoder output over the last
            axis, (num_hyps, max_hyps_len, vocab_size)
    """
    assert encoder_out.size(0) == 1
    num_hyps = hyps.size(0)
    assert hyps_lens.size(0) == num_hyps
    # Share the single encoder output across all hypotheses.
    encoder_out = encoder_out.repeat(num_hyps, 1, 1)
    # (B, 1, T): every encoder frame is marked valid.
    encoder_mask = paddle.ones(
        [num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
    # (num_hyps, max_hyps_len, vocab_size)
    decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
                                  hyps_lens)
    # paddle's log_softmax names this argument `axis`; `dim` is not an
    # accepted keyword.
    decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
    return decoder_out
|
||||
|
||||
@paddle.no_grad()
def decode(self,
           feats: paddle.Tensor,
           feats_lengths: paddle.Tensor,
           text_feature: Dict[str, int],
           decoding_method: str,
           lang_model_path: str,
           beam_alpha: float,
           beam_beta: float,
           beam_size: int,
           cutoff_prob: float,
           cutoff_top_n: int,
           num_processes: int,
           ctc_weight: float=0.0,
           decoding_chunk_size: int=-1,
           num_decoding_left_chunks: int=-1,
           simulate_streaming: bool=False):
    """Translate a batch of features and convert the result with text_feature.

    Only ``decoding_method == 'fullsentence'`` is implemented; it runs
    ``self.translate`` (beam search) and passes each hypothesis through
    ``text_feature.defeaturize``.

    Args:
        feats (Tensor): audio features, (B, T, D)
        feats_lengths (Tensor): (B)
        text_feature: object providing ``defeaturize(List[int])``.
            NOTE(review): annotated as Dict[str, int], but the body calls
            ``.defeaturize`` on it - presumably a TextFeaturizer; confirm
            against callers.
        decoding_method (str): decoding mode; only 'fullsentence' is
            accepted.
        lang_model_path (str): lm path. Not used by this method.
        beam_alpha (float): lm weight. Not used by this method.
        beam_beta (float): length penalty. Not used by this method.
        beam_size (int): beam size for search
        cutoff_prob (float): for prune. Not used by this method.
        cutoff_top_n (int): for prune. Not used by this method.
        num_processes (int): Not used by this method.
        ctc_weight (float, optional): Not used by this method.
            Defaults to 0.0.
        decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
            0: used for training, it's prohibited here.
        num_decoding_left_chunks (int, optional):
            number of left chunks for decoding. Defaults to -1.
        simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.

    Raises:
        ValueError: when not support decoding_method.

    Returns:
        list: one ``text_feature.defeaturize`` result per utterance.
    """
    if decoding_method != 'fullsentence':
        raise ValueError(f"Not support decoding method: {decoding_method}")

    best_hyps = self.translate(
        feats,
        feats_lengths,
        beam_size=beam_size,
        decoding_chunk_size=decoding_chunk_size,
        num_decoding_left_chunks=num_decoding_left_chunks,
        simulate_streaming=simulate_streaming)
    # Convert each hypothesis to a plain list before defeaturizing.
    return [text_feature.defeaturize(hyp.tolist()) for hyp in best_hyps]
|
||||
|
||||
|
||||
class U2STModel(U2STBaseModel):
    """U2 speech-translation model assembled from a config dict."""

    def __init__(self, configs: dict):
        """Build the sub modules from ``configs`` and initialise the base model.

        Args:
            configs (dict): config dict; ``configs['model_conf']`` is
                forwarded to the base class as keyword arguments.
        """
        vocab_size, encoder, decoder = U2STModel._init_from_config(configs)

        # With ASR joint training, _init_from_config returns a
        # (st_decoder, asr_decoder, ctc) tuple; otherwise only the ST decoder.
        if isinstance(decoder, tuple):
            st_decoder, asr_decoder, ctc = decoder
            super().__init__(
                vocab_size=vocab_size,
                encoder=encoder,
                st_decoder=st_decoder,
                decoder=asr_decoder,
                ctc=ctc,
                **configs['model_conf'])
        else:
            super().__init__(
                vocab_size=vocab_size,
                encoder=encoder,
                st_decoder=decoder,
                **configs['model_conf'])

    @classmethod
    def _init_from_config(cls, configs: dict):
        """init sub module for model.

        Args:
            configs (dict): config dict.

        Raises:
            ValueError: raise when using not support encoder type.

        Returns:
            Tuple: ``(vocab_size, encoder, decoders)``. ``decoders`` is the
                ST decoder alone when ``asr_weight`` is not positive, and
                the tuple ``(st_decoder, asr_decoder, ctc)`` otherwise.
        """
        if configs['cmvn_file'] is not None:
            mean, istd = load_cmvn(configs['cmvn_file'],
                                   configs['cmvn_file_type'])
            global_cmvn = GlobalCMVN(
                paddle.to_tensor(mean, dtype=paddle.float),
                paddle.to_tensor(istd, dtype=paddle.float))
        else:
            global_cmvn = None

        input_dim = configs['input_dim']
        vocab_size = configs['output_dim']
        assert input_dim != 0, input_dim
        assert vocab_size != 0, vocab_size

        encoder_type = configs.get('encoder', 'transformer')
        logger.info(f"U2 Encoder type: {encoder_type}")
        if encoder_type == 'transformer':
            encoder = TransformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        elif encoder_type == 'conformer':
            encoder = ConformerEncoder(
                input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
        else:
            raise ValueError(f"not support encoder type:{encoder_type}")

        st_decoder = TransformerDecoder(vocab_size,
                                        encoder.output_size(),
                                        **configs['decoder_conf'])

        # Default to 0.0 when the key is absent, matching the default of the
        # base-class constructor.
        asr_weight = configs['model_conf'].get('asr_weight', 0.0)
        logger.info(f"ASR Joint Training Weight: {asr_weight}")

        if asr_weight > 0.:
            # The ASR decoder is built from the same decoder_conf as the
            # ST decoder.
            decoder = TransformerDecoder(vocab_size,
                                         encoder.output_size(),
                                         **configs['decoder_conf'])
            ctc = CTCDecoder(
                odim=vocab_size,
                enc_n_units=encoder.output_size(),
                blank_id=0,
                dropout_rate=0.0,
                reduction=True,  # sum
                batch_average=True)  # sum / batch_size

            return vocab_size, encoder, (st_decoder, decoder, ctc)
        else:
            return vocab_size, encoder, st_decoder

    @classmethod
    def from_config(cls, configs: dict):
        """init model.

        Args:
            configs (dict): config dict.

        Raises:
            ValueError: raise when using not support encoder type.

        Returns:
            nn.Layer: U2STModel
        """
        model = cls(configs)
        return model

    @classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
        """Build a model from config and optionally load a checkpoint.

        ``config`` is modified in place: it is defrosted, ``input_dim`` and
        ``output_dim`` are set from the dataloader, and it is frozen again.

        Args:
            dataloader (paddle.io.DataLoader): its ``collate_fn`` supplies
                ``feature_size`` and ``vocab_size``.
            config (yacs.config.CfgNode): model configs
            checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name

        Returns:
            The model built by ``cls.from_config``; parameters are loaded
            only when ``checkpoint_path`` is truthy.
        """
        config.defrost()
        config.input_dim = dataloader.collate_fn.feature_size
        config.output_dim = dataloader.collate_fn.vocab_size
        config.freeze()
        model = cls.from_config(config)

        if checkpoint_path:
            infos = checkpoint.load_parameters(
                model, checkpoint_path=checkpoint_path)
            logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model
|
||||
|
||||
|
||||
class U2STInferModel(U2STModel):
    """U2STModel variant whose ``forward`` runs ``translate``."""

    def __init__(self, configs: dict):
        super().__init__(configs)

    def forward(self,
                feats,
                feats_lengths,
                decoding_chunk_size=-1,
                num_decoding_left_chunks=-1,
                simulate_streaming=False):
        """export model function

        ``translate`` is called with its default beam size.

        Args:
            feats (Tensor): [B, T, D]
            feats_lengths (Tensor): [B]
            decoding_chunk_size (int): forwarded to ``translate``.
            num_decoding_left_chunks (int): forwarded to ``translate``.
            simulate_streaming (bool): forwarded to ``translate``.

        Returns:
            The best-path result returned by ``self.translate``.
        """
        best_path = self.translate(
            feats,
            feats_lengths,
            simulate_streaming=simulate_streaming,
            num_decoding_left_chunks=num_decoding_left_chunks,
            decoding_chunk_size=decoding_chunk_size)
        return best_path
|
@ -0,0 +1,53 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""This module provides functions to calculate bleu score in different level.
|
||||
e.g. bleu for word-level, char_bleu for char-level.
|
||||
"""
|
||||
import numpy as np
|
||||
import sacrebleu
|
||||
|
||||
__all__ = ['bleu', 'char_bleu']
|
||||
|
||||
|
||||
def bleu(hypothesis, reference):
    """Calculate corpus BLEU between hypothesis and reference text.

    Both arguments are handed to ``sacrebleu.corpus_bleu`` unchanged.

    :param hypothesis: The hypothesis sentences.
    :type hypothesis: list[str]
    :param reference: The reference sentences.
    :type reference: list[list[str]]
    :return: The value returned by ``sacrebleu.corpus_bleu``.
    """
    score = sacrebleu.corpus_bleu(hypothesis, reference)
    return score
|
||||
|
||||
def char_bleu(hypothesis, reference):
    """Calculate corpus BLEU at char level using sacrebleu.

    Spaces are removed from every sentence and the remaining characters are
    re-joined with single spaces before scoring, so each character is
    scored as a separate space-delimited token.

    :param hypothesis: The hypothesis sentences.
    :type hypothesis: list[str]
    :param reference: The reference sentences.
    :type reference: list[list[str]]
    :return: The value returned by ``sacrebleu.corpus_bleu``.
    """
    # str.join iterates a string character by character, so no intermediate
    # list is needed.
    hypothesis = [' '.join(hyp.replace(' ', '')) for hyp in hypothesis]
    reference = [[' '.join(ref_i.replace(' ', '')) for ref_i in ref]
                 for ref in reference]

    return sacrebleu.corpus_bleu(hypothesis, reference)
|
Loading…
Reference in new issue