feat_dim, vocab_size

pull/665/head
Haoxin Ma 4 years ago
parent 3855522ee3
commit b9110af9d3

@ -137,7 +137,7 @@ class DeepSpeech2Trainer(Trainer):
def setup_dataloader(self): def setup_dataloader(self):
config = self.config.clone() config = self.config.clone()
config.defrost() config.defrost()
config.data.keep_transcription_text = False config.collator.keep_transcription_text = False
config.data.manifest = config.data.train_manifest config.data.manifest = config.data.train_manifest
train_dataset = ManifestDataset.from_config(config) train_dataset = ManifestDataset.from_config(config)
@ -165,7 +165,7 @@ class DeepSpeech2Trainer(Trainer):
sortagrad=config.data.sortagrad, sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method) shuffle_method=config.data.shuffle_method)
collate_fn = SpeechCollator(config=config, keep_transcription_text=False) collate_fn = SpeechCollator.from_config(config)
self.train_loader = DataLoader( self.train_loader = DataLoader(
train_dataset, train_dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,

@ -104,50 +104,7 @@ class SpeechFeaturizer(object):
speech_segment.transcript) speech_segment.transcript)
return spec_feature, text_ids return spec_feature, text_ids
@property
def vocab_size(self):
"""Return the vocabulary size.
Returns:
int: Vocabulary size.
"""
return self._text_featurizer.vocab_size
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]:
"""
return self._text_featurizer.vocab_list
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]:
"""
return self._text_featurizer.vocab_dict
@property
def feature_size(self):
"""Return the audio feature size.
Returns:
int: audio feature size.
"""
return self._audio_featurizer.feature_size
@property
def stride_ms(self):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
"""
return self._audio_featurizer.stride_ms
@property @property
def text_feature(self): def text_feature(self):

@ -82,7 +82,7 @@ def read_manifest(
] ]
if all(conditions): if all(conditions):
manifest.append(json_data) manifest.append(json_data)
return manifest return manifest, json_data["feat_shape"][-1]
def rms_to_db(rms: float): def rms_to_db(rms: float):

@ -22,6 +22,8 @@ from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment from deepspeech.frontend.speech import SpeechSegment
import io import io
import time import time
from yacs.config import CfgNode
from typing import Optional
from collections import namedtuple from collections import namedtuple
@ -33,51 +35,134 @@ logger = Log(__name__).getlog()
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object']) TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
class SpeechCollator(): class SpeechCollator():
def __init__(self, config, keep_transcription_text=True): @classmethod
""" def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
Padding audio features with zeros to make them have the same shape (or default = CfgNode(
a user-defined shape) within one bach. dict(
augmentation_config="",
random_seed=0,
mean_std_filepath="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=True
))
if ``keep_transcription_text`` is False, text is token ids else is raw string. if config is not None:
config.merge_from_other_cfg(default)
return default
@classmethod
def from_config(cls, config):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
""" """
self._keep_transcription_text = keep_transcription_text assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'mean_std_filepath' in config.collator
assert 'vocab_filepath' in config.data
assert 'specgram_type' in config.collator
assert 'n_fft' in config.collator
assert config.collator
if isinstance(config.data.augmentation_config, (str, bytes)): if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.data.augmentation_config: if config.collator.augmentation_config:
aug_file = io.open( aug_file = io.open(
config.data.augmentation_config, mode='r', encoding='utf8') config.collator.augmentation_config, mode='r', encoding='utf8')
else: else:
aug_file = io.StringIO(initial_value='{}', newline='') aug_file = io.StringIO(initial_value='{}', newline='')
else: else:
aug_file = config.data.augmentation_config aug_file = config.collator.augmentation_config
assert isinstance(aug_file, io.StringIO) assert isinstance(aug_file, io.StringIO)
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
mean_std_filepath=config.collator.mean_std_filepath,
unit_type=config.collator.unit_type,
vocab_filepath=config.data.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
specgram_type=config.collator.specgram_type,
feat_dim=config.collator.feat_dim,
delta_delta=config.collator.delta_delta,
stride_ms=config.collator.stride_ms,
window_ms=config.collator.window_ms,
n_fft=config.collator.n_fft,
max_freq=config.collator.max_freq,
target_sample_rate=config.collator.target_sample_rate,
use_dB_normalization=config.collator.use_dB_normalization,
target_dB=config.collator.target_dB,
dither=config.collator.dither,
keep_transcription_text=config.collator.keep_transcription_text
)
return speech_collator
def __init__(self, aug_file, mean_std_filepath,
vocab_filepath, spm_model_prefix,
random_seed=0,
unit_type="char",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0,
keep_transcription_text=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one bach.
if ``keep_transcription_text`` is False, text is token ids else is raw string.
"""
self._keep_transcription_text = keep_transcription_text
self._local_data = TarLocalData(tar2info={}, tar2object={}) self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline( self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=aug_file.read(), augmentation_config=aug_file.read(),
random_seed=config.data.random_seed) random_seed=random_seed)
self._normalizer = FeatureNormalizer( self._normalizer = FeatureNormalizer(
config.data.mean_std_filepath) if config.data.mean_std_filepath else None mean_std_filepath) if mean_std_filepath else None
self._stride_ms = config.data.stride_ms self._stride_ms = stride_ms
self._target_sample_rate = config.data.target_sample_rate self._target_sample_rate = target_sample_rate
self._speech_featurizer = SpeechFeaturizer( self._speech_featurizer = SpeechFeaturizer(
unit_type=config.data.unit_type, unit_type=unit_type,
vocab_filepath=config.data.vocab_filepath, vocab_filepath=vocab_filepath,
spm_model_prefix=config.data.spm_model_prefix, spm_model_prefix=spm_model_prefix,
specgram_type=config.data.specgram_type, specgram_type=specgram_type,
feat_dim=config.data.feat_dim, feat_dim=feat_dim,
delta_delta=config.data.delta_delta, delta_delta=delta_delta,
stride_ms=config.data.stride_ms, stride_ms=stride_ms,
window_ms=config.data.window_ms, window_ms=window_ms,
n_fft=config.data.n_fft, n_fft=n_fft,
max_freq=config.data.max_freq, max_freq=max_freq,
target_sample_rate=config.data.target_sample_rate, target_sample_rate=target_sample_rate,
use_dB_normalization=config.data.use_dB_normalization, use_dB_normalization=use_dB_normalization,
target_dB=config.data.target_dB, target_dB=target_dB,
dither=config.data.dither) dither=dither)
def _parse_tar(self, file): def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object """Parse a tar file to get a tarfile object
@ -196,3 +281,28 @@ class SpeechCollator():
texts, padding_value=IGNORE_ID).astype(np.int64) texts, padding_value=IGNORE_ID).astype(np.int64)
text_lens = np.array(text_lens).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, padded_texts, text_lens return utts, padded_audios, audio_lens, padded_texts, text_lens
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._text_featurizer
self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms

@ -55,20 +55,6 @@ class ManifestDataset(Dataset):
min_output_len=0.0, min_output_len=0.0,
max_output_input_ratio=float('inf'), max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0, min_output_input_ratio=0.0,
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
raw_wav=True, # use raw_wav or kaldi feature
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
dither=1.0, # feature dither
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
random_seed=0,
keep_transcription_text=False,
batch_size=32, # batch size batch_size=32, # batch size
num_workers=0, # data loader workers num_workers=0, # data loader workers
sortagrad=False, # sorted in first epoch when True sortagrad=False, # sorted in first epoch when True
@ -116,21 +102,19 @@ class ManifestDataset(Dataset):
min_output_len=config.data.min_output_len, min_output_len=config.data.min_output_len,
max_output_input_ratio=config.data.max_output_input_ratio, max_output_input_ratio=config.data.max_output_input_ratio,
min_output_input_ratio=config.data.min_output_input_ratio, min_output_input_ratio=config.data.min_output_input_ratio,
stride_ms=config.data.stride_ms, )
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
feat_dim=config.data.feat_dim,
delta_delta=config.data.delta_delta,
dither=config.data.dither,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=config.data.keep_transcription_text)
return dataset return dataset
def _read_vocab(self, vocab_filepath):
"""Load vocabulary from file."""
vocab_lines = []
with open(vocab_filepath, 'r', encoding='utf-8') as file:
vocab_lines.extend(file.readlines())
vocab_list = [line[:-1] for line in vocab_lines]
return vocab_list
def __init__(self, def __init__(self,
manifest_path, manifest_path,
unit_type, unit_type,
@ -143,20 +127,7 @@ class ManifestDataset(Dataset):
max_output_len=float('inf'), max_output_len=float('inf'),
min_output_len=0.0, min_output_len=0.0,
max_output_input_ratio=float('inf'), max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0, min_output_input_ratio=0.0):
stride_ms=10.0,
window_ms=20.0,
n_fft=None,
max_freq=None,
target_sample_rate=16000,
specgram_type='linear',
feat_dim=None,
delta_delta=False,
dither=1.0,
use_dB_normalization=True,
target_dB=-20,
random_seed=0,
keep_transcription_text=False):
"""Manifest Dataset """Manifest Dataset
Args: Args:
@ -186,30 +157,11 @@ class ManifestDataset(Dataset):
keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
""" """
super().__init__() super().__init__()
self._stride_ms = stride_ms
self._target_sample_rate = target_sample_rate # self._rng = np.random.RandomState(random_seed)
self._speech_featurizer = SpeechFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
specgram_type=specgram_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=stride_ms,
window_ms=window_ms,
n_fft=n_fft,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=dither)
self._rng = np.random.RandomState(random_seed)
self._keep_transcription_text = keep_transcription_text
# read manifest # read manifest
self._manifest = read_manifest( self._manifest, self._feature_size = read_manifest(
manifest_path=manifest_path, manifest_path=manifest_path,
max_input_len=max_input_len, max_input_len=max_input_len,
min_input_len=min_input_len, min_input_len=min_input_len,
@ -219,10 +171,60 @@ class ManifestDataset(Dataset):
min_output_input_ratio=min_output_input_ratio) min_output_input_ratio=min_output_input_ratio)
self._manifest.sort(key=lambda x: x["feat_shape"][0]) self._manifest.sort(key=lambda x: x["feat_shape"][0])
self._vocab_list = self._read_vocab(vocab_filepath)
@property @property
def manifest(self): def manifest(self):
return self._manifest return self._manifest
@property
def vocab_size(self):
"""Return the vocabulary size.
Returns:
int: Vocabulary size.
"""
return len(self._vocab_list)
@property
def vocab_list(self):
"""Return the vocabulary in list.
Returns:
List[str]:
"""
return self._vocab_list
@property
def vocab_dict(self):
"""Return the vocabulary in dict.
Returns:
Dict[str, int]:
"""
vocab_dict = dict(
[(token, idx) for (idx, token) in enumerate(self._vocab_list)])
return vocab_dict
@property
def feature_size(self):
"""Return the audio feature size.
Returns:
int: audio feature size.
"""
return self._feature_size
@property
def stride_ms(self):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
"""
return self._audio_featurizer.stride_ms
def __len__(self): def __len__(self):
return len(self._manifest) return len(self._manifest)

@ -4,9 +4,10 @@ data:
dev_manifest: data/manifest.tiny dev_manifest: data/manifest.tiny
test_manifest: data/manifest.tiny test_manifest: data/manifest.tiny
mean_std_filepath: data/mean_std.json mean_std_filepath: data/mean_std.json
unit_type: char
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json augmentation_config: conf/augmentation.json
batch_size: 2 batch_size: 4
min_input_len: 0.0 min_input_len: 0.0
max_input_len: 27.0 max_input_len: 27.0
min_output_len: 0.0 min_output_len: 0.0
@ -29,6 +30,24 @@ data:
shuffle_method: batch_shuffle shuffle_method: batch_shuffle
num_workers: 0 num_workers: 0
collator:
augmentation_config: conf/augmentation.json
random_seed: 0
mean_std_filepath: data/mean_std.json
spm_model_prefix:
specgram_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: True
model: model:
num_conv_layers: 2 num_conv_layers: 2
num_rnn_layers: 3 num_rnn_layers: 3
@ -37,7 +56,7 @@ model:
share_rnn_weights: True share_rnn_weights: True
training: training:
n_epoch: 10 n_epoch: 21
lr: 1e-5 lr: 1e-5
lr_decay: 1.0 lr_decay: 1.0
weight_decay: 1e-06 weight_decay: 1e-06

Loading…
Cancel
Save