diff --git a/deepspeech/frontend/augmentor/base.py b/deepspeech/frontend/augmentor/base.py
index 0f7826cdf..3bc37e68a 100644
--- a/deepspeech/frontend/augmentor/base.py
+++ b/deepspeech/frontend/augmentor/base.py
@@ -41,3 +41,17 @@ class AugmentorBase():
         :type audio_segment: AudioSegmenet|SpeechSegment
         """
         pass
+
+    @abstractmethod
+    def transform_spectrogram(self, spec_segment):
+        """Adds various effects to the input spectrogram segment. Such effects
+        will augment the training data to make the model invariant to certain
+        types of time and frequency masking in the real world, improving the
+        model's generalization ability.
+
+        Note that this is an in-place transformation.
+
+        :param spec_segment: Spectrogram segment to add effects to.
+        :type spec_segment: Spectrogram
+        """
+        pass
\ No newline at end of file
diff --git a/deepspeech/frontend/featurizer/audio_featurizer.py b/deepspeech/frontend/featurizer/audio_featurizer.py
index 799525e55..6aef47622 100644
--- a/deepspeech/frontend/featurizer/audio_featurizer.py
+++ b/deepspeech/frontend/featurizer/audio_featurizer.py
@@ -17,6 +17,7 @@ import numpy as np
 from deepspeech.frontend.utility import read_manifest
 from deepspeech.frontend.audio import AudioSegment
 from python_speech_features import mfcc
+from python_speech_features import logfbank
 from python_speech_features import delta
 
 
@@ -49,7 +50,9 @@ class AudioFeaturizer(object):
     """
 
     def __init__(self,
-                 specgram_type='linear',
+                 specgram_type: str='linear',
+                 feat_dim: int=None,
+                 delta_delta: bool=False,
                  stride_ms=10.0,
                  window_ms=20.0,
                  n_fft=None,
@@ -58,6 +61,8 @@ class AudioFeaturizer(object):
                  use_dB_normalization=True,
                  target_dB=-20):
         self._specgram_type = specgram_type
+        self._feat_dim = feat_dim
+        self._delta_delta = delta_delta
         self._stride_ms = stride_ms
         self._window_ms = window_ms
         self._max_freq = max_freq
@@ -110,7 +115,12 @@ class AudioFeaturizer(object):
                 1)
         elif self._specgram_type == 'mfcc':
             # mfcc, delta, delta-delta
-            feat_dim = int(13 * 3)
+            feat_dim = int(self._feat_dim *
+                           3) if self._delta_delta else int(self._feat_dim)
+        elif self._specgram_type == 'fbank':
+            # fbank, delta, delta-delta
+            feat_dim = int(self._feat_dim *
+                           3) if self._delta_delta else int(self._feat_dim)
         else:
             raise ValueError("Unknown specgram_type %s. "
-                             "Supported values: linear." % self._specgram_type)
+                             "Supported values: linear, mfcc, fbank." % self._specgram_type)
@@ -123,8 +133,23 @@ class AudioFeaturizer(object):
             samples, sample_rate, self._stride_ms, self._window_ms,
             self._max_freq)
         elif self._specgram_type == 'mfcc':
-            return self._compute_mfcc(samples, sample_rate, self._stride_ms,
-                                      self._window_ms, self._max_freq)
+            return self._compute_mfcc(
+                samples,
+                sample_rate,
+                self._feat_dim,
+                self._stride_ms,
+                self._window_ms,
+                self._max_freq,
+                delta_delta=self._delta_delta)
+        elif self._specgram_type == 'fbank':
+            return self._compute_fbank(
+                samples,
+                sample_rate,
+                self._feat_dim,
+                self._stride_ms,
+                self._window_ms,
+                self._max_freq,
+                delta_delta=self._delta_delta)
         else:
             raise ValueError("Unknown specgram_type %s. "
-                             "Supported values: linear." % self._specgram_type)
+                             "Supported values: linear, mfcc, fbank." % self._specgram_type)
@@ -179,13 +204,54 @@ class AudioFeaturizer(object):
         freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
         return fft, freqs
 
+    def _concat_delta_delta(self, feat):
+        """Append delta and delta-delta features.
+
+        Args:
+            feat (np.ndarray): (D, T)
+
+        Returns:
+            np.ndarray: feat with delta and delta-delta, (3*D, T)
+        """
+        feat = np.transpose(feat)
+        # Deltas
+        d_feat = delta(feat, 2)
+        # Deltas-Deltas
+        dd_feat = delta(d_feat, 2)
+        # transpose back to (D, T)
+        feat = np.transpose(feat)
+        d_feat = np.transpose(d_feat)
+        dd_feat = np.transpose(dd_feat)
+        # concat above three features
+        concat_feat = np.concatenate((feat, d_feat, dd_feat))
+        return concat_feat
+
     def _compute_mfcc(self,
                       samples,
                       sample_rate,
+                      feat_dim=13,
                       stride_ms=10.0,
                       window_ms=20.0,
-                      max_freq=None):
-        """Compute mfcc from samples."""
+                      max_freq=None,
+                      delta_delta=True):
+        """Compute mfcc from samples.
+
+        Args:
+            samples (np.ndarray): the audio signal from which to compute features. Should be an N*1 array.
+            sample_rate (float): the sample rate of the signal we are working with, in Hz.
+            feat_dim (int): the number of cepstral coefficients to return, default 13.
+            stride_ms (float, optional): stride length in ms. Defaults to 10.0.
+            window_ms (float, optional): window length in ms. Defaults to 20.0.
+            max_freq (float, optional): highest band edge of mel filters, in Hz. Defaults to None (samplerate/2).
+            delta_delta (bool, optional): whether to stack delta and delta-delta features. Defaults to True.
+
+        Raises:
+            ValueError: max_freq > samplerate/2
+            ValueError: stride_ms > window_ms
+
+        Returns:
+            np.ndarray: mfcc feature, (D, T).
+        """
         if max_freq is None:
             max_freq = sample_rate / 2
         if max_freq > sample_rate / 2:
@@ -195,22 +261,73 @@ class AudioFeaturizer(object):
             raise ValueError("Stride size must not be greater than "
                              "window size.")
         # compute the 13 cepstral coefficients, and the first one is replaced
-        # by log(frame energy)
+        # by log(frame energy), (T, D)
         mfcc_feat = mfcc(
             signal=samples,
             samplerate=sample_rate,
             winlen=0.001 * window_ms,
             winstep=0.001 * stride_ms,
-            highfreq=max_freq)
-        # Deltas
-        d_mfcc_feat = delta(mfcc_feat, 2)
-        # Deltas-Deltas
-        dd_mfcc_feat = delta(d_mfcc_feat, 2)
-        # transpose
+            numcep=feat_dim,
+            nfilt=2 * feat_dim,
+            nfft=None,
+            lowfreq=0,
+            highfreq=max_freq,
+            preemph=0.97,
+            ceplifter=22,
+            appendEnergy=True,
+            winfunc=lambda x: np.ones((x, )))
         mfcc_feat = np.transpose(mfcc_feat)
-        d_mfcc_feat = np.transpose(d_mfcc_feat)
-        dd_mfcc_feat = np.transpose(dd_mfcc_feat)
-        # concat above three features
-        concat_mfcc_feat = np.concatenate(
-            (mfcc_feat, d_mfcc_feat, dd_mfcc_feat))
-        return concat_mfcc_feat
+        if delta_delta:
+            mfcc_feat = self._concat_delta_delta(mfcc_feat)
+        return mfcc_feat
+
+    def _compute_fbank(self,
+                       samples,
+                       sample_rate,
+                       feat_dim=26,
+                       stride_ms=10.0,
+                       window_ms=20.0,
+                       max_freq=None,
+                       delta_delta=False):
+        """Compute logfbank from samples.
+
+        Args:
+            samples (np.ndarray): the audio signal from which to compute features. Should be an N*1 array.
+            sample_rate (float): the sample rate of the signal we are working with, in Hz.
+            feat_dim (int): the number of mel filterbanks to return, default 26.
+            stride_ms (float, optional): stride length in ms. Defaults to 10.0.
+            window_ms (float, optional): window length in ms. Defaults to 20.0.
+            max_freq (float, optional): highest band edge of mel filters, in Hz. Defaults to None (samplerate/2).
+            delta_delta (bool, optional): whether to stack delta and delta-delta features. Defaults to False.
+
+        Raises:
+            ValueError: max_freq > samplerate/2
+            ValueError: stride_ms > window_ms
+
+        Returns:
+            np.ndarray: fbank feature, (D, T).
+ """ + if max_freq is None: + max_freq = sample_rate / 2 + if max_freq > sample_rate / 2: + raise ValueError("max_freq must not be greater than half of " + "sample rate.") + if stride_ms > window_ms: + raise ValueError("Stride size must not be greater than " + "window size.") + #(T, D) + fbank_feat = logfbank( + signal=samples, + samplerate=sample_rate, + winlen=0.001 * window_ms, + winstep=0.001 * stride_ms, + nfilt=feat_dim, + nfft=512, + lowfreq=max_freq, + highfreq=None, + preemph=0.97, + winfunc=lambda x: np.ones((x, ))) + fbank_feat = np.transpose(fbank_feat) + if delta_delta: + fbank_feat = self._concat_delta_delta(fbank_feat) + return fbank_feat diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index 894c684bf..e8f92798b 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -56,6 +56,8 @@ class SpeechFeaturizer(object): vocab_filepath, spm_model_prefix=None, specgram_type='linear', + feat_dim=13, + delta_delta=True, stride_ms=10.0, window_ms=20.0, n_fft=None, @@ -65,6 +67,8 @@ class SpeechFeaturizer(object): target_dB=-20): self._audio_featurizer = AudioFeaturizer( specgram_type=specgram_type, + feat_dim=feat_dim, + delta_delta=delta_delta, stride_ms=stride_ms, window_ms=window_ms, n_fft=n_fft, @@ -82,17 +86,22 @@ class SpeechFeaturizer(object): 2. For transcript parts, keep the original text or convert text string to a list of token indices in char-level. - :param audio_segment: Speech segment to extract features from. - :type audio_segment: SpeechSegment - :return: A tuple of 1) spectrogram audio feature in 2darray, 2) list of - char-level token indices. - :rtype: tuple + Args: + speech_segment (SpeechSegment): Speech segment to extract features from. + keep_transcription_text (bool): True, keep transcript text, False, token ids + + Returns: + tuple: 1) spectrogram audio feature in 2darray, 2) list oftoken indices. """ - audio_feature = self._audio_featurizer.featurize(speech_segment) + spec_feature = self._audio_featurizer.featurize(speech_segment) if keep_transcription_text: - return audio_feature, speech_segment.transcript - text_ids = self._text_featurizer.featurize(speech_segment.transcript) - return audio_feature, text_ids + return spec_feature, speech_segment.transcript + if speech_segment.has_token: + text_ids = speech_segment.token_ids + else: + text_ids = self._text_featurizer.featurize( + speech_segment.transcript) + return spec_feature, text_ids @property def vocab_size(self): diff --git a/deepspeech/frontend/speech.py b/deepspeech/frontend/speech.py index 2883405bb..ec24e8a9e 100644 --- a/deepspeech/frontend/speech.py +++ b/deepspeech/frontend/speech.py @@ -24,7 +24,12 @@ class SpeechSegment(AudioSegment): AudioSegment (AudioSegment): Audio Segment """ - def __init__(self, samples, sample_rate, transcript): + def __init__(self, + samples, + sample_rate, + transcript, + tokens=None, + token_ids=None): """Speech segment abstraction, a subclass of AudioSegment, with an additional transcript. @@ -32,9 +37,13 @@ class SpeechSegment(AudioSegment): samples (ndarray.float32): Audio samples [num_samples x num_channels]. sample_rate (int): Audio sample rate. transcript (str): Transcript text for the speech. + tokens (List[str], optinal): Transcript tokens for the speech. + token_ids (List[int], optional): Transcript token ids for the speech. 
""" AudioSegment.__init__(self, samples, sample_rate) self._transcript = transcript + self._tokens = tokens + self._token_ids = token_ids def __eq__(self, other): """Return whether two objects are equal. @@ -46,6 +55,11 @@ class SpeechSegment(AudioSegment): return False if self._transcript != other._transcript: return False + if self.has_token and other.has_token: + if self._tokens != other._tokens: + return False + if self._token_ids != other._token_ids: + return False return True def __ne__(self, other): @@ -53,33 +67,39 @@ class SpeechSegment(AudioSegment): return not self.__eq__(other) @classmethod - def from_file(cls, filepath, transcript): + def from_file(cls, filepath, transcript, tokens=None, token_ids=None): """Create speech segment from audio file and corresponding transcript. - - :param filepath: Filepath or file object to audio file. - :type filepath: str|file - :param transcript: Transcript text for the speech. - :type transript: str - :return: Speech segment instance. - :rtype: SpeechSegment + + Args: + filepath (str|file): Filepath or file object to audio file. + transcript (str): Transcript text for the speech. + tokens (List[str], optional): text tokens. Defaults to None. + token_ids (List[int], optional): text token ids. Defaults to None. + + Returns: + SpeechSegment: Speech segment instance. """ + audio = AudioSegment.from_file(filepath) - return cls(audio.samples, audio.sample_rate, transcript) + return cls(audio.samples, audio.sample_rate, transcript, tokens, + token_ids) @classmethod - def from_bytes(cls, bytes, transcript): + def from_bytes(cls, bytes, transcript, tokens=None, token_ids=None): """Create speech segment from a byte string and corresponding - transcript. - - :param bytes: Byte string containing audio samples. - :type bytes: str - :param transcript: Transcript text for the speech. - :type transript: str - :return: Speech segment instance. - :rtype: Speech Segment + + Args: + filepath (str|file): Filepath or file object to audio file. + transcript (str): Transcript text for the speech. + tokens (List[str], optional): text tokens. Defaults to None. + token_ids (List[int], optional): text token ids. Defaults to None. + + Returns: + SpeechSegment: Speech segment instance. """ audio = AudioSegment.from_bytes(bytes) - return cls(audio.samples, audio.sample_rate, transcript) + return cls(audio.samples, audio.sample_rate, transcript, tokens, + token_ids) @classmethod def concatenate(cls, *segments): @@ -98,6 +118,8 @@ class SpeechSegment(AudioSegment): raise ValueError("No speech segments are given to concatenate.") sample_rate = segments[0]._sample_rate transcripts = "" + tokens = [] + token_ids = [] for seg in segments: if sample_rate != seg._sample_rate: raise ValueError("Can't concatenate segments with " @@ -106,11 +128,20 @@ class SpeechSegment(AudioSegment): raise TypeError("Only speech segments of the same type " "instance can be concatenated.") transcripts += seg._transcript + if self.has_token: + tokens += seg._tokens + token_ids += seg._token_ids samples = np.concatenate([seg.samples for seg in segments]) - return cls(samples, sample_rate, transcripts) + return cls(samples, sample_rate, transcripts, tokens, token_ids) @classmethod - def slice_from_file(cls, filepath, transcript, start=None, end=None): + def slice_from_file(cls, + filepath, + transcript, + tokens=None, + token_ids=None, + start=None, + end=None): """Loads a small section of an speech without having to load the entire file into the memory which can be incredibly wasteful. 
@@ -132,28 +163,54 @@ class SpeechSegment(AudioSegment):
         :rtype: SpeechSegment
         """
         audio = AudioSegment.slice_from_file(filepath, start, end)
-        return cls(audio.samples, audio.sample_rate, transcript)
+        return cls(audio.samples, audio.sample_rate, transcript, tokens,
+                   token_ids)
 
     @classmethod
     def make_silence(cls, duration, sample_rate):
         """Creates a silent speech segment of the given duration and
         sample rate, transcript will be an empty string.
 
-        :param duration: Length of silence in seconds.
-        :type duration: float
-        :param sample_rate: Sample rate.
-        :type sample_rate: float
-        :return: Silence of the given duration.
-        :rtype: SpeechSegment
+        Args:
+            duration (float): Length of silence in seconds.
+            sample_rate (float): Sample rate.
+
+        Returns:
+            SpeechSegment: Silence of the given duration.
         """
         audio = AudioSegment.make_silence(duration, sample_rate)
         return cls(audio.samples, audio.sample_rate, "")
 
+    @property
+    def has_token(self):
+        if self._tokens or self._token_ids:
+            return True
+        return False
+
     @property
     def transcript(self):
         """Return the transcript text.
 
-        :return: Transcript text for the speech.
-        :rtype: str
+        Returns:
+            str: Transcript text for the speech.
         """
+        return self._transcript
+
+    @property
+    def tokens(self):
+        """Return the transcript text tokens.
+
+        Returns:
+            List[str]: text tokens.
+        """
+        return self._tokens
+
+    @property
+    def token_ids(self):
+        """Return the transcript text token ids.
+
+        Returns:
+            List[int]: text token ids.
+        """
+        return self._token_ids
diff --git a/utils/compute_mean_std.py b/utils/compute_mean_std.py
index 80fe88813..339813748 100644
--- a/utils/compute_mean_std.py
+++ b/utils/compute_mean_std.py
@@ -26,8 +26,14 @@ add_arg = functools.partial(add_arguments, argparser=parser)
 add_arg('num_samples',      int,    2000,    "# of samples to for statistics.")
 add_arg('specgram_type',    str,
         'linear',
-        "Audio feature type. Options: linear, mfcc.",
-        choices=['linear', 'mfcc'])
+        "Audio feature type. Options: linear, mfcc, fbank.",
+        choices=['linear', 'mfcc', 'fbank'])
+add_arg('feat_dim',         int,
+        13,
+        "Audio feature dim.")
+add_arg('delta_delta',      bool,
+        False,
+        "Audio feature with delta delta.")
 add_arg('manifest_path',    str,
         'data/librispeech/manifest.train',
         "Filepath of manifest to compute normalizer's mean and stddev.")
@@ -42,7 +48,10 @@ def main():
     print_arguments(args)
 
     augmentation_pipeline = AugmentationPipeline('{}')
-    audio_featurizer = AudioFeaturizer(specgram_type=args.specgram_type)
+    audio_featurizer = AudioFeaturizer(
+        specgram_type=args.specgram_type,
+        feat_dim=args.feat_dim,
+        delta_delta=args.delta_delta)
 
     def augment_and_featurize(audio_segment):
         augmentation_pipeline.transform_audio(audio_segment)
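---
Reviewer note, not part of the patch: a minimal usage sketch of the new fbank
path and the token-carrying SpeechSegment introduced above. The wav path and
token ids below are hypothetical placeholders; a 16 kHz mono file is assumed.

    from deepspeech.frontend.audio import AudioSegment
    from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
    from deepspeech.frontend.speech import SpeechSegment

    # New 'fbank' option alongside 'linear' and 'mfcc'; feat_dim is the number
    # of mel filterbanks (nfilt), and delta_delta=True stacks the base feature
    # with its delta and delta-delta, giving shape (3 * feat_dim, T).
    featurizer = AudioFeaturizer(
        specgram_type='fbank', feat_dim=26, delta_delta=True)

    audio = AudioSegment.from_file('sample.wav')  # placeholder path
    feat = featurizer.featurize(audio)
    print(feat.shape)  # (78, T)

    # SpeechSegment can now carry precomputed tokens/token ids, which
    # SpeechFeaturizer.featurize prefers over re-tokenizing the transcript.
    seg = SpeechSegment.from_file(
        'sample.wav', 'hello world',
        tokens=['hello', 'world'], token_ids=[42, 7])  # hypothetical ids
    assert seg.has_token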