add fbank, refactor feature of speech

pull/578/head
Hui Zhang 5 years ago
parent ed793b30b7
commit 177f463daa

@ -41,3 +41,17 @@ class AugmentorBase():
:type audio_segment: AudioSegment|SpeechSegment
"""
pass
@abstractmethod
def transform_spectrogram(self, spec_segment):
"""Adds various effects to the input spectrogram segment. Such effects
will augment the training data to make the model invariant to certain
types of time_mask or freq_mask in the real world, improving model's
generalization ability.
Note that this is an in-place transformation.
:param spec_segment: Spectrogram segment to add effects to.
:type spec_segment: Spectrogram
"""
pass
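As a reference for implementers, here is a minimal hypothetical subclass (not part of this commit) that fills in the new hook with a SpecAugment-style frequency mask; it assumes `spec_segment` is a (D, T) numpy array:

import numpy as np

class FreqMaskAugmentor(AugmentorBase):
    """Hypothetical example: zero out one random band of frequency bins."""

    def __init__(self, rng, max_mask_width=8):
        self._rng = rng  # e.g. np.random.RandomState(seed)
        self._max_mask_width = max_mask_width

    def transform_audio(self, audio_segment):
        # This augmentor only operates on spectrograms.
        pass

    def transform_spectrogram(self, spec_segment):
        # In-place mutation, as the base-class docstring requires.
        num_bins = spec_segment.shape[0]
        width = self._rng.randint(0, self._max_mask_width + 1)
        start = self._rng.randint(0, max(1, num_bins - width))
        spec_segment[start:start + width, :] = 0.0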

@ -17,6 +17,7 @@ import numpy as np
from deepspeech.frontend.utility import read_manifest
from deepspeech.frontend.audio import AudioSegment
from python_speech_features import mfcc
from python_speech_features import logfbank
from python_speech_features import delta
@ -49,7 +50,9 @@ class AudioFeaturizer(object):
"""
def __init__(self,
specgram_type='linear',
specgram_type: str='linear',
feat_dim: int=None,
delta_delta: bool=False,
stride_ms=10.0,
window_ms=20.0,
n_fft=None,
@ -58,6 +61,8 @@ class AudioFeaturizer(object):
use_dB_normalization=True,
target_dB=-20):
self._specgram_type = specgram_type
self._feat_dim = feat_dim
self._delta_delta = delta_delta
self._stride_ms = stride_ms
self._window_ms = window_ms
self._max_freq = max_freq
@ -110,7 +115,12 @@ class AudioFeaturizer(object):
1)
elif self._specgram_type == 'mfcc':
# mfcc, delta, delta-delta
feat_dim = int(13 * 3)
feat_dim = int(self._feat_dim *
3) if self._delta_delta else int(self._feat_dim)
elif self._specgram_type == 'fbank':
# fbank, delta, delta-delta
feat_dim = int(self._feat_dim *
3) if self._delta_delta else int(self._feat_dim)
else:
raise ValueError("Unknown specgram_type %s. "
"Supported values: linear." % self._specgram_type)
@ -123,8 +133,23 @@ class AudioFeaturizer(object):
samples, sample_rate, self._stride_ms, self._window_ms,
self._max_freq)
elif self._specgram_type == 'mfcc':
return self._compute_mfcc(samples, sample_rate, self._stride_ms,
self._window_ms, self._max_freq)
return self._compute_mfcc(
samples,
sample_rate,
self._feat_dim,
self._stride_ms,
self._window_ms,
self._max_freq,
delta_delta=self._delta_delta)
elif self._specgram_type == 'fbank':
return self._compute_fbank(
samples,
sample_rate,
self._feat_dim,
self._stride_ms,
self._window_ms,
self._max_freq,
delta_delta=self._delta_delta)
else:
raise ValueError("Unknown specgram_type %s. "
"Supported values: linear." % self._specgram_type)
@ -179,13 +204,54 @@ class AudioFeaturizer(object):
freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
return fft, freqs
def _concat_delta_delta(self, feat):
"""append delat, delta-delta feature.
Args:
feat (np.ndarray): (D, T)
Returns:
np.ndarray: feat with delta-delta, (3*D, T)
"""
feat = np.transpose(feat)
# Deltas
d_feat = delta(feat, 2)
# Deltas-Deltas
dd_feat = delta(d_feat, 2)
# transpose
feat = np.transpose(feat)
d_feat = np.transpose(d_feat)
dd_feat = np.transpose(dd_feat)
# concat above three features
concat_feat = np.concatenate((feat, d_feat, dd_feat))
return concat_feat
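To make the transpose dance explicit: `python_speech_features.delta` expects frame-major (T, D) input, so a shape check for `_concat_delta_delta` might look like this (assuming the library is installed):

import numpy as np
from python_speech_features import delta

feat = np.random.rand(13, 100)   # (D, T), as the docstring above states
t_feat = np.transpose(feat)      # (T, D) layout that delta() expects
d_feat = delta(t_feat, 2)        # first-order deltas, still (T, D)
dd_feat = delta(d_feat, 2)       # second-order deltas
concat = np.concatenate(
    (feat, np.transpose(d_feat), np.transpose(dd_feat)))
assert concat.shape == (3 * 13, 100)  # (3*D, T)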
def _compute_mfcc(self,
samples,
sample_rate,
feat_dim=13,
stride_ms=10.0,
window_ms=20.0,
max_freq=None):
"""Compute mfcc from samples."""
max_freq=None,
delta_delta=True):
"""Compute mfcc from samples.
Args:
samples (np.ndarray): the audio signal from which to compute features. Should be an N*1 array
sample_rate (float): the sample rate of the signal we are working with, in Hz.
feat_dim (int): the number of cepstral coefficients to return. Defaults to 13.
stride_ms (float, optional): stride length in ms. Defaults to 10.0.
window_ms (float, optional): window length in ms. Defaults to 20.0.
max_freq (float, optional): highest band edge of mel filters, in Hz. Defaults to None (samplerate/2).
delta_delta (bool, optional): whether to append delta and delta-delta features. Defaults to True.
Raises:
ValueError: max_freq > samplerate/2
ValueError: stride_ms > window_ms
Returns:
np.ndarray: mfcc feature, (D, T).
"""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
@ -195,22 +261,73 @@ class AudioFeaturizer(object):
raise ValueError("Stride size must not be greater than "
"window size.")
# compute the feat_dim cepstral coefficients, and the first one is replaced
# by log(frame energy)
# by log(frame energy), (T, D)
mfcc_feat = mfcc(
signal=samples,
samplerate=sample_rate,
winlen=0.001 * window_ms,
winstep=0.001 * stride_ms,
highfreq=max_freq)
# Deltas
d_mfcc_feat = delta(mfcc_feat, 2)
# Deltas-Deltas
dd_mfcc_feat = delta(d_mfcc_feat, 2)
# transpose
numcep=feat_dim,
nfilt=2 * feat_dim,
nfft=None,
lowfreq=0,
highfreq=max_freq,
preemph=0.97,
ceplifter=22,
appendEnergy=True,
winfunc=lambda x: np.ones((x, )))
mfcc_feat = np.transpose(mfcc_feat)
d_mfcc_feat = np.transpose(d_mfcc_feat)
dd_mfcc_feat = np.transpose(dd_mfcc_feat)
# concat above three features
concat_mfcc_feat = np.concatenate(
(mfcc_feat, d_mfcc_feat, dd_mfcc_feat))
return concat_mfcc_feat
if delta_delta:
mfcc_feat = self._concat_delta_delta(mfcc_feat)
return mfcc_feat
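Similarly, an isolated call matching the mfcc path above (synthetic input; the numbers mirror this commit's defaults):

import numpy as np
from python_speech_features import mfcc

sample_rate = 16000
samples = np.random.randn(sample_rate)  # 1 second of placeholder audio

mfcc_feat = mfcc(
    signal=samples,
    samplerate=sample_rate,
    winlen=0.020,       # window_ms = 20.0
    winstep=0.010,      # stride_ms = 10.0
    numcep=13,          # feat_dim
    nfilt=26,           # 2 * feat_dim
    lowfreq=0,
    highfreq=sample_rate // 2,
    appendEnergy=True)
print(mfcc_feat.shape)  # (~99, 13); the caller transposes to (13, T)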
def _compute_fbank(self,
samples,
sample_rate,
feat_dim=26,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
delta_delta=False):
"""Compute logfbank from samples.
Args:
samples (np.ndarray): the audio signal from which to compute features. Should be an N*1 array
sample_rate (float): the sample rate of the signal we are working with, in Hz.
feat_dim (int): the number of mel filter banks to return. Defaults to 26.
stride_ms (float, optional): stride length in ms. Defaults to 10.0.
window_ms (float, optional): window length in ms. Defaults to 20.0.
max_freq (float, optional): highest band edge of mel filters, in Hz. Defaults to None (samplerate/2).
delta_delta (bool, optional): whether to append delta and delta-delta features. Defaults to False.
Raises:
ValueError: max_freq > samplerate/2
ValueError: stride_ms > window_ms
Returns:
np.ndarray: fbank feature, (D, T).
"""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
# (T, D)
fbank_feat = logfbank(
signal=samples,
samplerate=sample_rate,
winlen=0.001 * window_ms,
winstep=0.001 * stride_ms,
nfilt=feat_dim,
nfft=512,
lowfreq=0,
highfreq=max_freq,
preemph=0.97,
winfunc=lambda x: np.ones((x, )))
fbank_feat = np.transpose(fbank_feat)
if delta_delta:
fbank_feat = self._concat_delta_delta(fbank_feat)
return fbank_feat
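For reference, a standalone call equivalent to the fbank path above, with `lowfreq`/`highfreq` in the corrected order (illustrative values; a 16 kHz synthetic signal stands in for real audio):

import numpy as np
from python_speech_features import logfbank

sample_rate = 16000
samples = np.random.randn(sample_rate)  # 1 second of placeholder audio

fbank_feat = logfbank(
    signal=samples,
    samplerate=sample_rate,
    winlen=0.020,    # window_ms = 20.0
    winstep=0.010,   # stride_ms = 10.0
    nfilt=26,        # feat_dim
    nfft=512,
    lowfreq=0,
    highfreq=sample_rate // 2)
fbank_feat = np.transpose(fbank_feat)  # (T, D) -> (D, T)
print(fbank_feat.shape)                # (26, ~99)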

@ -56,6 +56,8 @@ class SpeechFeaturizer(object):
vocab_filepath,
spm_model_prefix=None,
specgram_type='linear',
feat_dim=13,
delta_delta=True,
stride_ms=10.0,
window_ms=20.0,
n_fft=None,
@ -65,6 +67,8 @@ class SpeechFeaturizer(object):
target_dB=-20):
self._audio_featurizer = AudioFeaturizer(
specgram_type=specgram_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=stride_ms,
window_ms=window_ms,
n_fft=n_fft,
@ -82,17 +86,22 @@ class SpeechFeaturizer(object):
2. For transcript parts, keep the original text or convert text string
to a list of token indices in char-level.
:param audio_segment: Speech segment to extract features from.
:type audio_segment: SpeechSegment
:return: A tuple of 1) spectrogram audio feature in 2darray, 2) list of
char-level token indices.
:rtype: tuple
Args:
speech_segment (SpeechSegment): Speech segment to extract features from.
keep_transcription_text (bool): if True, return the transcript text; if False, return token ids.
Returns:
tuple: 1) spectrogram audio feature in 2darray, 2) list of token indices.
"""
audio_feature = self._audio_featurizer.featurize(speech_segment)
spec_feature = self._audio_featurizer.featurize(speech_segment)
if keep_transcription_text:
return audio_feature, speech_segment.transcript
text_ids = self._text_featurizer.featurize(speech_segment.transcript)
return audio_feature, text_ids
return spec_feature, speech_segment.transcript
if speech_segment.has_token:
text_ids = speech_segment.token_ids
else:
text_ids = self._text_featurizer.featurize(
speech_segment.transcript)
return spec_feature, text_ids
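A hypothetical end-to-end use of the refactored featurizer (file names and vocab path are placeholders, not from this commit):

featurizer = SpeechFeaturizer(
    vocab_filepath='data/vocab.txt',   # placeholder path
    specgram_type='fbank',
    feat_dim=26,
    delta_delta=True)
segment = SpeechSegment.from_file('utt_0001.wav', 'hello world')
spec_feature, text_ids = featurizer.featurize(
    segment, keep_transcription_text=False)
# spec_feature: ndarray of shape (3 * 26, T)
# text_ids: token ids from the text featurizer, since this segment was
# created without precomputed token_ids (has_token is False)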
@property
def vocab_size(self):

@ -24,7 +24,12 @@ class SpeechSegment(AudioSegment):
AudioSegment (AudioSegment): Audio Segment
"""
def __init__(self, samples, sample_rate, transcript):
def __init__(self,
samples,
sample_rate,
transcript,
tokens=None,
token_ids=None):
"""Speech segment abstraction, a subclass of AudioSegment,
with an additional transcript.
@ -32,9 +37,13 @@ class SpeechSegment(AudioSegment):
samples (ndarray.float32): Audio samples [num_samples x num_channels].
sample_rate (int): Audio sample rate.
transcript (str): Transcript text for the speech.
tokens (List[str], optional): Transcript tokens for the speech.
token_ids (List[int], optional): Transcript token ids for the speech.
"""
AudioSegment.__init__(self, samples, sample_rate)
self._transcript = transcript
self._tokens = tokens
self._token_ids = token_ids
def __eq__(self, other):
"""Return whether two objects are equal.
@ -46,6 +55,11 @@ class SpeechSegment(AudioSegment):
return False
if self._transcript != other._transcript:
return False
if self.has_token and other.has_token:
if self._tokens != other._tokens:
return False
if self._token_ids != other._token_ids:
return False
return True
def __ne__(self, other):
@ -53,33 +67,39 @@ class SpeechSegment(AudioSegment):
return not self.__eq__(other)
@classmethod
def from_file(cls, filepath, transcript):
def from_file(cls, filepath, transcript, tokens=None, token_ids=None):
"""Create speech segment from audio file and corresponding transcript.
:param filepath: Filepath or file object to audio file.
:type filepath: str|file
:param transcript: Transcript text for the speech.
:type transript: str
:return: Speech segment instance.
:rtype: SpeechSegment
Args:
filepath (str|file): Filepath or file object to audio file.
transcript (str): Transcript text for the speech.
tokens (List[str], optional): text tokens. Defaults to None.
token_ids (List[int], optional): text token ids. Defaults to None.
Returns:
SpeechSegment: Speech segment instance.
"""
audio = AudioSegment.from_file(filepath)
return cls(audio.samples, audio.sample_rate, transcript)
return cls(audio.samples, audio.sample_rate, transcript, tokens,
token_ids)
@classmethod
def from_bytes(cls, bytes, transcript):
def from_bytes(cls, bytes, transcript, tokens=None, token_ids=None):
"""Create speech segment from a byte string and corresponding
transcript.
:param bytes: Byte string containing audio samples.
:type bytes: str
:param transcript: Transcript text for the speech.
:type transript: str
:return: Speech segment instance.
:rtype: Speech Segment
Args:
bytes (bytes): Byte string containing audio samples.
transcript (str): Transcript text for the speech.
tokens (List[str], optional): text tokens. Defaults to None.
token_ids (List[int], optional): text token ids. Defaults to None.
Returns:
SpeechSegment: Speech segment instance.
"""
audio = AudioSegment.from_bytes(bytes)
return cls(audio.samples, audio.sample_rate, transcript)
return cls(audio.samples, audio.sample_rate, transcript, tokens,
token_ids)
@classmethod
def concatenate(cls, *segments):
@ -98,6 +118,8 @@ class SpeechSegment(AudioSegment):
raise ValueError("No speech segments are given to concatenate.")
sample_rate = segments[0]._sample_rate
transcripts = ""
tokens = []
token_ids = []
for seg in segments:
if sample_rate != seg._sample_rate:
raise ValueError("Can't concatenate segments with "
@ -106,11 +128,20 @@ class SpeechSegment(AudioSegment):
raise TypeError("Only speech segments of the same type "
"instance can be concatenated.")
transcripts += seg._transcript
if seg.has_token:
tokens += seg._tokens
token_ids += seg._token_ids
samples = np.concatenate([seg.samples for seg in segments])
return cls(samples, sample_rate, transcripts)
return cls(samples, sample_rate, transcripts, tokens, token_ids)
@classmethod
def slice_from_file(cls, filepath, transcript, start=None, end=None):
def slice_from_file(cls,
filepath,
transcript,
tokens=None,
token_ids=None,
start=None,
end=None):
"""Loads a small section of an speech without having to load
the entire file into the memory which can be incredibly wasteful.
@ -132,28 +163,54 @@ class SpeechSegment(AudioSegment):
:rtype: SpeechSegment
"""
audio = AudioSegment.slice_from_file(filepath, start, end)
return cls(audio.samples, audio.sample_rate, transcript)
return cls(audio.samples, audio.sample_rate, transcript, tokens,
token_ids)
@classmethod
def make_silence(cls, duration, sample_rate):
"""Creates a silent speech segment of the given duration and
sample rate, transcript will be an empty string.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silence of the given duration.
:rtype: SpeechSegment
Args:
duration (float): Length of silence in seconds.
sample_rate (float): Sample rate.
Returns:
SpeechSegment: Silence of the given duration.
"""
audio = AudioSegment.make_silence(duration, sample_rate)
return cls(audio.samples, audio.sample_rate, "")
@property
def has_token(self):
"""Whether the segment carries tokens or token ids.
Returns:
bool: True if tokens or token ids are available.
"""
return bool(self._tokens or self._token_ids)
@property
def transcript(self):
"""Return the transcript text.
:return: Transcript text for the speech.
:rtype: str
Returns:
str: Transcript text for the speech.
"""
return self._transcript
@property
def tokens(self):
"""Return the transcript text tokens.
Returns:
List[str]: text tokens.
"""
return self._tokens
@property
def token_ids(self):
"""Return the transcript text token ids.
Returns:
List[int]: text token ids.
"""
return self._token_ids
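An illustrative construction showing how the new token fields behave (values are made up):

import numpy as np

samples = np.zeros(16000, dtype='float32')
seg = SpeechSegment(
    samples, 16000, 'hello world',
    tokens=['hello', 'world'],
    token_ids=[42, 7])
assert seg.has_token
assert seg.tokens == ['hello', 'world']
assert seg.token_ids == [42, 7]

# Without tokens, has_token is False and SpeechFeaturizer.featurize
# falls back to tokenizing the transcript text.
seg2 = SpeechSegment(samples, 16000, 'hello world')
assert not seg2.has_token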

@ -26,8 +26,14 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('num_samples', int, 2000, "# of samples to use for statistics.")
add_arg('specgram_type', str,
'linear',
"Audio feature type. Options: linear, mfcc.",
choices=['linear', 'mfcc'])
"Audio feature type. Options: linear, mfcc, fbank.",
choices=['linear', 'mfcc', 'fbank'])
add_arg('feat_dim', int,
13,
"Audio feature dim.")
add_arg('delta_delta', bool,
False,
"Audio feature with delta delta.")
add_arg('manifest_path', str,
'data/librispeech/manifest.train',
"Filepath of manifest to compute normalizer's mean and stddev.")
@ -42,7 +48,10 @@ def main():
print_arguments(args)
augmentation_pipeline = AugmentationPipeline('{}')
audio_featurizer = AudioFeaturizer(specgram_type=args.specgram_type)
audio_featurizer = AudioFeaturizer(
specgram_type=args.specgram_type,
feat_dim=args.feat_dim,
delta_delta=args.delta_delta)
def augment_and_featurize(audio_segment):
augmentation_pipeline.transform_audio(audio_segment)
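The new flags would be exercised like `--specgram_type fbank --feat_dim 26 --delta_delta True`; in code, the equivalent featurizer construction is the sketch below (the closing return of `augment_and_featurize` is truncated in this hunk and assumed here):

audio_featurizer = AudioFeaturizer(
    specgram_type='fbank',   # --specgram_type fbank
    feat_dim=26,             # --feat_dim 26
    delta_delta=True)        # --delta_delta True

def augment_and_featurize(audio_segment):
    augmentation_pipeline.transform_audio(audio_segment)
    return audio_featurizer.featurize(audio_segment)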
