add fbank, refactor feature of speech

pull/578/head
Hui Zhang 5 years ago
parent ed793b30b7
commit 177f463daa

@@ -41,3 +41,17 @@ class AugmentorBase():
        :type audio_segment: AudioSegment|SpeechSegment
        """
        pass

    @abstractmethod
    def transform_spectrogram(self, spec_segment):
        """Adds various effects to the input spectrogram segment. Such effects
        will augment the training data to make the model invariant to certain
        types of time and frequency masking in the real world, improving the
        model's generalization ability.

        Note that this is an in-place transformation.

        :param spec_segment: Spectrogram segment to add effects to.
        :type spec_segment: Spectrogram
        """
        pass
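For orientation, here is a minimal sketch (not part of this commit) of a concrete subclass implementing the new hook. The `MaskAugmentor` name, the constructor parameters, and the assumption that `spec_segment` is a (D, T) numpy array are all illustrative:

    import numpy as np

    class MaskAugmentor(AugmentorBase):
        """Hypothetical augmentor: zero out one random time span and one
        random frequency band of the spectrogram, in place."""

        def __init__(self, rng, max_time=40, max_freq_bins=8):
            self._rng = rng  # e.g. np.random.RandomState
            self._max_time = max_time
            self._max_freq_bins = max_freq_bins

        def transform_audio(self, audio_segment):
            pass  # this augmentor only touches spectrograms

        def transform_spectrogram(self, spec_segment):
            # spec_segment assumed to be a (D, T) np.ndarray
            num_bins, num_steps = spec_segment.shape
            t = self._rng.randint(0, self._max_time)
            t0 = self._rng.randint(0, max(1, num_steps - t))
            spec_segment[:, t0:t0 + t] = 0.0  # time mask
            f = self._rng.randint(0, self._max_freq_bins)
            f0 = self._rng.randint(0, max(1, num_bins - f))
            spec_segment[f0:f0 + f, :] = 0.0  # frequency mask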

@@ -17,6 +17,7 @@ import numpy as np
from deepspeech.frontend.utility import read_manifest
from deepspeech.frontend.audio import AudioSegment
from python_speech_features import mfcc
from python_speech_features import logfbank
from python_speech_features import delta
@@ -49,7 +50,9 @@ class AudioFeaturizer(object):
    """

    def __init__(self,
                 specgram_type: str='linear',
                 feat_dim: int=None,
                 delta_delta: bool=False,
                 stride_ms=10.0,
                 window_ms=20.0,
                 n_fft=None,
@@ -58,6 +61,8 @@ class AudioFeaturizer(object):
                 use_dB_normalization=True,
                 target_dB=-20):
        self._specgram_type = specgram_type
        self._feat_dim = feat_dim
        self._delta_delta = delta_delta
        self._stride_ms = stride_ms
        self._window_ms = window_ms
        self._max_freq = max_freq
@@ -110,7 +115,12 @@ class AudioFeaturizer(object):
                              1)
        elif self._specgram_type == 'mfcc':
            # mfcc, delta, delta-delta
            feat_dim = int(self._feat_dim *
                           3) if self._delta_delta else int(self._feat_dim)
        elif self._specgram_type == 'fbank':
            # fbank, delta, delta-delta
            feat_dim = int(self._feat_dim *
                           3) if self._delta_delta else int(self._feat_dim)
        else:
            raise ValueError("Unknown specgram_type %s. "
                             "Supported values: linear, mfcc, fbank." %
                             self._specgram_type)
@@ -123,8 +133,23 @@ class AudioFeaturizer(object):
                samples, sample_rate, self._stride_ms, self._window_ms,
                self._max_freq)
        elif self._specgram_type == 'mfcc':
            return self._compute_mfcc(
                samples,
                sample_rate,
                feat_dim=self._feat_dim,
                stride_ms=self._stride_ms,
                window_ms=self._window_ms,
                max_freq=self._max_freq,
                delta_delta=self._delta_delta)
        elif self._specgram_type == 'fbank':
            return self._compute_fbank(
                samples,
                sample_rate,
                feat_dim=self._feat_dim,
                stride_ms=self._stride_ms,
                window_ms=self._window_ms,
                max_freq=self._max_freq,
                delta_delta=self._delta_delta)
        else:
            raise ValueError("Unknown specgram_type %s. "
                             "Supported values: linear, mfcc, fbank." %
                             self._specgram_type)
@@ -179,13 +204,54 @@ class AudioFeaturizer(object):
        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
        return fft, freqs

    def _concat_delta_delta(self, feat):
        """Append delta and delta-delta features.

        Args:
            feat (np.ndarray): (D, T)

        Returns:
            np.ndarray: feat with delta and delta-delta, (3*D, T)
        """
        # python_speech_features expects (T, D)
        feat = np.transpose(feat)
        # deltas
        d_feat = delta(feat, 2)
        # delta-deltas, computed from the deltas
        dd_feat = delta(d_feat, 2)
        # back to (D, T)
        feat = np.transpose(feat)
        d_feat = np.transpose(d_feat)
        dd_feat = np.transpose(dd_feat)
        # concat the three features along the feature axis
        concat_feat = np.concatenate((feat, d_feat, dd_feat))
        return concat_feat

    def _compute_mfcc(self,
                      samples,
                      sample_rate,
                      feat_dim=13,
                      stride_ms=10.0,
                      window_ms=20.0,
                      max_freq=None,
                      delta_delta=True):
        """Compute MFCC features from samples.

        Args:
            samples (np.ndarray): the audio signal from which to compute features; should be an N*1 array.
            sample_rate (float): the sample rate of the signal we are working with, in Hz.
            feat_dim (int): the number of cepstra to return. Defaults to 13.
            stride_ms (float, optional): stride length in ms. Defaults to 10.0.
            window_ms (float, optional): window length in ms. Defaults to 20.0.
            max_freq (float, optional): highest band edge of mel filters, in Hz. Defaults to None (samplerate/2).
            delta_delta (bool, optional): whether to append delta and delta-delta features. Defaults to True.

        Raises:
            ValueError: max_freq > samplerate/2
            ValueError: stride_ms > window_ms

        Returns:
            np.ndarray: mfcc feature, (D, T).
        """
        if max_freq is None:
            max_freq = sample_rate / 2
@@ -195,22 +261,73 @@ class AudioFeaturizer(object):
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        # compute feat_dim cepstral coefficients; the first one is replaced
        # by log(frame energy), (T, D)
        mfcc_feat = mfcc(
            signal=samples,
            samplerate=sample_rate,
            winlen=0.001 * window_ms,
            winstep=0.001 * stride_ms,
            numcep=feat_dim,
            nfilt=2 * feat_dim,
            nfft=None,
            lowfreq=0,
            highfreq=max_freq,
            preemph=0.97,
            ceplifter=22,
            appendEnergy=True,
            winfunc=lambda x: np.ones((x, )))
        # (D, T)
        mfcc_feat = np.transpose(mfcc_feat)
        if delta_delta:
            mfcc_feat = self._concat_delta_delta(mfcc_feat)
        return mfcc_feat

    def _compute_fbank(self,
                       samples,
                       sample_rate,
                       feat_dim=26,
                       stride_ms=10.0,
                       window_ms=20.0,
                       max_freq=None,
                       delta_delta=False):
        """Compute log-mel filterbank (fbank) features from samples.

        Args:
            samples (np.ndarray): the audio signal from which to compute features; should be an N*1 array.
            sample_rate (float): the sample rate of the signal we are working with, in Hz.
            feat_dim (int): the number of mel filters to return. Defaults to 26.
            stride_ms (float, optional): stride length in ms. Defaults to 10.0.
            window_ms (float, optional): window length in ms. Defaults to 20.0.
            max_freq (float, optional): highest band edge of mel filters, in Hz. Defaults to None (samplerate/2).
            delta_delta (bool, optional): whether to append delta and delta-delta features. Defaults to False.

        Raises:
            ValueError: max_freq > samplerate/2
            ValueError: stride_ms > window_ms

        Returns:
            np.ndarray: fbank feature, (D, T).
        """
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        # (T, D)
        fbank_feat = logfbank(
            signal=samples,
            samplerate=sample_rate,
            winlen=0.001 * window_ms,
            winstep=0.001 * stride_ms,
            nfilt=feat_dim,
            nfft=512,
            lowfreq=0,
            highfreq=max_freq,
            preemph=0.97,
            winfunc=lambda x: np.ones((x, )))
        # (D, T)
        fbank_feat = np.transpose(fbank_feat)
        if delta_delta:
            fbank_feat = self._concat_delta_delta(fbank_feat)
        return fbank_feat
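A minimal usage sketch for the new fbank path, assuming a 16 kHz mono signal; the private method is called directly here only for illustration:

    import numpy as np

    featurizer = AudioFeaturizer(
        specgram_type='fbank', feat_dim=26, delta_delta=True)
    samples = np.random.randn(16000).astype('float32')  # 1 s of noise
    feat = featurizer._compute_fbank(
        samples, sample_rate=16000, feat_dim=26, delta_delta=True)
    # (D, T) with D == 3 * 26 once delta and delta-delta are appended
    assert feat.shape[0] == 78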

@@ -56,6 +56,8 @@ class SpeechFeaturizer(object):
                 vocab_filepath,
                 spm_model_prefix=None,
                 specgram_type='linear',
                 feat_dim=13,
                 delta_delta=True,
                 stride_ms=10.0,
                 window_ms=20.0,
                 n_fft=None,
@@ -65,6 +67,8 @@ class SpeechFeaturizer(object):
                 target_dB=-20):
        self._audio_featurizer = AudioFeaturizer(
            specgram_type=specgram_type,
            feat_dim=feat_dim,
            delta_delta=delta_delta,
            stride_ms=stride_ms,
            window_ms=window_ms,
            n_fft=n_fft,
@@ -82,17 +86,22 @@ class SpeechFeaturizer(object):
        2. For transcript parts, keep the original text or convert text string
           to a list of token indices in char-level.

        Args:
            speech_segment (SpeechSegment): Speech segment to extract features from.
            keep_transcription_text (bool): if True, keep the transcript text; if False, return token ids.

        Returns:
            tuple: 1) spectrogram audio feature in 2darray, 2) list of token indices.
        """
        spec_feature = self._audio_featurizer.featurize(speech_segment)
        if keep_transcription_text:
            return spec_feature, speech_segment.transcript
        if speech_segment.has_token:
            text_ids = speech_segment.token_ids
        else:
            text_ids = self._text_featurizer.featurize(
                speech_segment.transcript)
        return spec_feature, text_ids

    @property
    def vocab_size(self):
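Sketch of the featurize() contract after this change; the vocab and wav paths below are placeholders:

    featurizer = SpeechFeaturizer(
        vocab_filepath='data/vocab.txt',  # placeholder path
        specgram_type='fbank',
        feat_dim=26,
        delta_delta=False)
    seg = SpeechSegment.from_file('utt.wav', transcript='hello')  # placeholder file
    feat, text = featurizer.featurize(seg, keep_transcription_text=True)
    # text is the raw transcript string
    feat, ids = featurizer.featurize(seg, keep_transcription_text=False)
    # ids come from seg.token_ids when the segment is pre-tokenized,
    # otherwise from the text featurizer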

@@ -24,7 +24,12 @@ class SpeechSegment(AudioSegment):
        AudioSegment (AudioSegment): Audio Segment
    """

    def __init__(self,
                 samples,
                 sample_rate,
                 transcript,
                 tokens=None,
                 token_ids=None):
        """Speech segment abstraction, a subclass of AudioSegment,
        with an additional transcript.

@@ -32,9 +37,13 @@ class SpeechSegment(AudioSegment):
            samples (ndarray.float32): Audio samples [num_samples x num_channels].
            sample_rate (int): Audio sample rate.
            transcript (str): Transcript text for the speech.
            tokens (List[str], optional): Transcript tokens for the speech.
            token_ids (List[int], optional): Transcript token ids for the speech.
        """
        AudioSegment.__init__(self, samples, sample_rate)
        self._transcript = transcript
        self._tokens = tokens
        self._token_ids = token_ids
    def __eq__(self, other):
        """Return whether two objects are equal.

@@ -46,6 +55,11 @@ class SpeechSegment(AudioSegment):
            return False
        if self._transcript != other._transcript:
            return False
        if self.has_token and other.has_token:
            if self._tokens != other._tokens:
                return False
            if self._token_ids != other._token_ids:
                return False
        return True
    def __ne__(self, other):
@@ -53,33 +67,39 @@ class SpeechSegment(AudioSegment):
        return not self.__eq__(other)

    @classmethod
    def from_file(cls, filepath, transcript, tokens=None, token_ids=None):
        """Create speech segment from audio file and corresponding transcript.

        Args:
            filepath (str|file): Filepath or file object to audio file.
            transcript (str): Transcript text for the speech.
            tokens (List[str], optional): Text tokens. Defaults to None.
            token_ids (List[int], optional): Text token ids. Defaults to None.

        Returns:
            SpeechSegment: Speech segment instance.
        """
        audio = AudioSegment.from_file(filepath)
        return cls(audio.samples, audio.sample_rate, transcript, tokens,
                   token_ids)
    @classmethod
    def from_bytes(cls, bytes, transcript, tokens=None, token_ids=None):
        """Create speech segment from a byte string and corresponding
        transcript.

        Args:
            bytes (bytes): Byte string containing audio samples.
            transcript (str): Transcript text for the speech.
            tokens (List[str], optional): Text tokens. Defaults to None.
            token_ids (List[int], optional): Text token ids. Defaults to None.

        Returns:
            SpeechSegment: Speech segment instance.
        """
        audio = AudioSegment.from_bytes(bytes)
        return cls(audio.samples, audio.sample_rate, transcript, tokens,
                   token_ids)
    @classmethod
    def concatenate(cls, *segments):
@@ -98,6 +118,8 @@ class SpeechSegment(AudioSegment):
            raise ValueError("No speech segments are given to concatenate.")
        sample_rate = segments[0]._sample_rate
        transcripts = ""
        tokens = []
        token_ids = []
        for seg in segments:
            if sample_rate != seg._sample_rate:
                raise ValueError("Can't concatenate segments with "
@@ -106,11 +128,20 @@ class SpeechSegment(AudioSegment):
                raise TypeError("Only speech segments of the same type "
                                "instance can be concatenated.")
            transcripts += seg._transcript
            if seg.has_token:
                tokens += seg._tokens
                token_ids += seg._token_ids
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate, transcripts, tokens, token_ids)
    @classmethod
    def slice_from_file(cls,
                        filepath,
                        transcript,
                        tokens=None,
                        token_ids=None,
                        start=None,
                        end=None):
        """Loads a small section of a speech file without having to load
        the entire file into the memory, which can be incredibly wasteful.

@@ -132,28 +163,54 @@ class SpeechSegment(AudioSegment):
        :rtype: SpeechSegment
        """
        audio = AudioSegment.slice_from_file(filepath, start, end)
        return cls(audio.samples, audio.sample_rate, transcript, tokens,
                   token_ids)
    @classmethod
    def make_silence(cls, duration, sample_rate):
        """Creates a silent speech segment of the given duration and
        sample rate; the transcript will be an empty string.

        Args:
            duration (float): Length of silence in seconds.
            sample_rate (float): Sample rate.

        Returns:
            SpeechSegment: Silence of the given duration.
        """
        audio = AudioSegment.make_silence(duration, sample_rate)
        return cls(audio.samples, audio.sample_rate, "")

    @property
    def has_token(self):
        """Return whether the segment carries tokens or token ids."""
        if self._tokens or self._token_ids:
            return True
        return False
    @property
    def transcript(self):
        """Return the transcript text.

        Returns:
            str: Transcript text for the speech.
        """
        return self._transcript

    @property
    def tokens(self):
        """Return the transcript text tokens.

        Returns:
            List[str]: Text tokens.
        """
        return self._tokens

    @property
    def token_ids(self):
        """Return the transcript text token ids.

        Returns:
            List[int]: Text token ids.
        """
        return self._token_ids
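The new token plumbing in action (a sketch; the token ids are illustrative):

    import numpy as np

    samples = np.zeros(16000, dtype='float32')
    seg = SpeechSegment(
        samples, 16000, transcript="hello",
        tokens=["he", "llo"], token_ids=[42, 7])  # illustrative ids
    assert seg.has_token and seg.token_ids == [42, 7]

    plain = SpeechSegment(samples, 16000, transcript="hello")
    assert not plain.has_token  # falls back to text featurization downstream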

@@ -26,8 +26,14 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('num_samples', int, 2000, "# of samples for statistics.")
add_arg('specgram_type', str,
        'linear',
        "Audio feature type. Options: linear, mfcc, fbank.",
        choices=['linear', 'mfcc', 'fbank'])
add_arg('feat_dim', int,
        13,
        "Audio feature dim.")
add_arg('delta_delta', bool,
        False,
        "Audio feature with delta delta.")
add_arg('manifest_path', str,
        'data/librispeech/manifest.train',
        "Filepath of manifest to compute normalizer's mean and stddev.")
@@ -42,7 +48,10 @@ def main():
    print_arguments(args)

    augmentation_pipeline = AugmentationPipeline('{}')
    audio_featurizer = AudioFeaturizer(
        specgram_type=args.specgram_type,
        feat_dim=args.feat_dim,
        delta_delta=args.delta_delta)

    def augment_and_featurize(audio_segment):
        augmentation_pipeline.transform_audio(audio_segment)
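End to end, the script's per-sample flow amounts to the sketch below (an assumption-laden illustration: `segments` stands in for audio loaded from the manifest, and the real script hands `augment_and_featurize` to a feature normalizer rather than looping by hand):

    import numpy as np

    feats = []
    for seg in segments:  # hypothetical list of AudioSegment objects
        augmentation_pipeline.transform_audio(seg)
        feats.append(audio_featurizer.featurize(seg))
    # with mfcc/fbank each element is a (feat_dim [* 3], T) array; stacking
    # along time yields the per-dimension statistics the normalizer needs
    stacked = np.concatenate(feats, axis=1)
    mean, std = stacked.mean(axis=1), stacked.std(axis=1)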
