add mfcc feature for DS2

pull/2/head
Yibing Liu 8 years ago
parent 8ce9546710
commit ee5abbe37d

@ -38,7 +38,11 @@ python datasets/librispeech/librispeech.py --help
python compute_mean_std.py python compute_mean_std.py
``` ```
`python compute_mean_std.py` computes the mean and standard deviation of the audio features and saves them to a file with the default name `./mean_std.npz`. This file is used in both training and inference. `python compute_mean_std.py` computes the mean and standard deviation of the audio features and saves them to a file with the default name `./mean_std.npz`. This file is used in both training and inference. The default audio feature is the power spectrum; the MFCC feature is also supported. To train and infer with MFCC features, regenerate this file with:
```
python compute_mean_std.py --specgram_type mfcc
```
More help for arguments: More help for arguments:

@ -10,6 +10,12 @@ from data_utils.featurizer.audio_featurizer import AudioFeaturizer
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Computing mean and stddev for feature normalizer.') description='Computing mean and stddev for feature normalizer.')
# Audio feature type used when computing normalization statistics.
# `choices` makes argparse reject an unsupported value at parse time
# instead of letting it fail later inside the featurizer.
parser.add_argument(
    "--specgram_type",
    default='linear',
    type=str,
    choices=['linear', 'mfcc'],
    help="Feature type of audio data: 'linear' (power spectrum)"
    " or 'mfcc'. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--manifest_path", "--manifest_path",
default='datasets/manifest.train', default='datasets/manifest.train',
@ -39,7 +45,7 @@ args = parser.parse_args()
def main(): def main():
augmentation_pipeline = AugmentationPipeline(args.augmentation_config) augmentation_pipeline = AugmentationPipeline(args.augmentation_config)
audio_featurizer = AudioFeaturizer() audio_featurizer = AudioFeaturizer(specgram_type=args.specgram_type)
def augment_and_featurize(audio_segment): def augment_and_featurize(audio_segment):
augmentation_pipeline.transform_audio(audio_segment) augmentation_pipeline.transform_audio(audio_segment)

@ -6,13 +6,15 @@ from __future__ import print_function
import numpy as np import numpy as np
from data_utils import utils from data_utils import utils
from data_utils.audio import AudioSegment from data_utils.audio import AudioSegment
from python_speech_features import mfcc
from python_speech_features import delta
class AudioFeaturizer(object): class AudioFeaturizer(object):
"""Audio featurizer, for extracting features from audio contents of """Audio featurizer, for extracting features from audio contents of
AudioSegment or SpeechSegment. AudioSegment or SpeechSegment.
Currently, it only supports feature type of linear spectrogram. Currently, it supports feature types of linear spectrogram and mfcc.
:param specgram_type: Specgram feature type. Options: 'linear'. :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
:type specgram_type: str :type specgram_type: str
@ -20,9 +22,10 @@ class AudioFeaturizer(object):
:type stride_ms: float :type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames. :param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float :type window_ms: float
:param max_freq: Used when specgram_type is 'linear', only FFT bins :param max_freq: When specgram_type is 'linear', only FFT bins
corresponding to frequencies between [0, max_freq] are corresponding to frequencies between [0, max_freq] are
returned. returned; when specgram_type is 'mfcc', max_freq is the
highest band edge of mel filters.
:type max_freq: None|float :type max_freq: None|float
:param target_sample_rate: Audio are resampled (if upsampling or :param target_sample_rate: Audio are resampled (if upsampling or
downsampling is allowed) to this before downsampling is allowed) to this before
@ -91,6 +94,9 @@ class AudioFeaturizer(object):
return self._compute_linear_specgram( return self._compute_linear_specgram(
samples, sample_rate, self._stride_ms, self._window_ms, samples, sample_rate, self._stride_ms, self._window_ms,
self._max_freq) self._max_freq)
elif self._specgram_type == 'mfcc':
return self._compute_mfcc(samples, sample_rate, self._stride_ms,
self._window_ms, self._max_freq)
else: else:
raise ValueError("Unknown specgram_type %s. " raise ValueError("Unknown specgram_type %s. "
"Supported values: linear." % self._specgram_type) "Supported values: linear." % self._specgram_type)
@ -142,3 +148,39 @@ class AudioFeaturizer(object):
# prepare fft frequency list # prepare fft frequency list
freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
return fft, freqs return fft, freqs
def _compute_mfcc(self,
samples,
sample_rate,
stride_ms=10.0,
window_ms=20.0,
max_freq=None):
"""Compute mfcc from samples."""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
# compute 13 cepstral coefficients, and the first one is replaced
# by log(frame energy)
mfcc_feat = mfcc(
signal=samples,
samplerate=sample_rate,
winlen=0.001 * window_ms,
winstep=0.001 * stride_ms,
highfreq=max_freq)
# Deltas
d_mfcc_feat = delta(mfcc_feat, 2)
# Deltas-Deltas
dd_mfcc_feat = delta(d_mfcc_feat, 2)
# concat above three features
concat_mfcc_feat = [
np.concatenate((mfcc_feat[i], d_mfcc_feat[i], dd_mfcc_feat[i]))
for i in xrange(len(mfcc_feat))
]
# transpose to be consistent with the linear specgram situation
concat_mfcc_feat = np.transpose(concat_mfcc_feat)
return concat_mfcc_feat

@ -11,23 +11,24 @@ class SpeechFeaturizer(object):
"""Speech featurizer, for extracting features from both audio and transcript """Speech featurizer, for extracting features from both audio and transcript
contents of SpeechSegment. contents of SpeechSegment.
Currently, for audio parts, it only supports feature type of linear Currently, for audio parts, it supports feature types of linear
spectrogram; for transcript parts, it only supports char-level tokenizing spectrogram and mfcc; for transcript parts, it only supports char-level
and conversion into a list of token indices. Note that the token indexing tokenizing and conversion into a list of token indices. Note that the
order follows the given vocabulary file. token indexing order follows the given vocabulary file.
:param vocab_filepath: Filepath to load vocabulary for token indices :param vocab_filepath: Filepath to load vocabulary for token indices
conversion. conversion.
:type vocab_filepath: basestring :type vocab_filepath: basestring
:param specgram_type: Specgram feature type. Options: 'linear'. :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
:type specgram_type: str :type specgram_type: str
:param stride_ms: Striding size (in milliseconds) for generating frames. :param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float :type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames. :param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float :type window_ms: float
:param max_freq: Used when specgram_type is 'linear', only FFT bins :param max_freq: When specgram_type is 'linear', only FFT bins
corresponding to frequencies between [0, max_freq] are corresponding to frequencies between [0, max_freq] are
returned. returned; when specgram_type is 'mfcc', max_freq is the
highest band edge of mel filters.
:type max_freq: None|float :type max_freq: None|float
:param target_sample_rate: Speech are resampled (if upsampling or :param target_sample_rate: Speech are resampled (if upsampling or
downsampling is allowed) to this before downsampling is allowed) to this before

@ -2,3 +2,4 @@ wget==3.2
scipy==0.13.1 scipy==0.13.1
resampy==0.1.5 resampy==0.1.5
https://github.com/kpu/kenlm/archive/master.zip https://github.com/kpu/kenlm/archive/master.zip
python_speech_features

@ -53,6 +53,12 @@ parser.add_argument(
default=True, default=True,
type=distutils.util.strtobool, type=distutils.util.strtobool,
help="Use sortagrad or not. (default: %(default)s)") help="Use sortagrad or not. (default: %(default)s)")
# Audio feature type used for training.
# `choices` restricts the option to the two documented feature types so
# a mistyped value is rejected at parse time.
parser.add_argument(
    "--specgram_type",
    default='linear',
    type=str,
    choices=['linear', 'mfcc'],
    help="Feature type of audio data: 'linear' (power spectrum)"
    " or 'mfcc'. (default: %(default)s)")
parser.add_argument( parser.add_argument(
"--max_duration", "--max_duration",
default=27.0, default=27.0,
@ -130,6 +136,7 @@ def train():
augmentation_config=args.augmentation_config, augmentation_config=args.augmentation_config,
max_duration=args.max_duration, max_duration=args.max_duration,
min_duration=args.min_duration, min_duration=args.min_duration,
specgram_type=args.specgram_type,
num_threads=args.num_threads_data) num_threads=args.num_threads_data)
train_generator = data_generator() train_generator = data_generator()

Loading…
Cancel
Save