Merge pull request #665 from PaddlePaddle/spec_aug

Move SpecAug into the collator
Hui Zhang 4 years ago committed by GitHub
commit 6487ca6022

deepspeech/exps/deepspeech2/config.py

@@ -11,74 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from yacs.config import CfgNode as CN
+from yacs.config import CfgNode

+from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester
+from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer
+from deepspeech.io.collator import SpeechCollator
+from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.deepspeech2 import DeepSpeech2Model

-_C = CN()
+_C = CfgNode()

-_C.data = CN(
-    dict(
-        train_manifest="",
-        dev_manifest="",
-        test_manifest="",
-        unit_type="char",
-        vocab_filepath="",
-        spm_model_prefix="",
-        mean_std_filepath="",
-        augmentation_config="",
-        max_duration=float('inf'),
-        min_duration=0.0,
-        stride_ms=10.0,  # ms
-        window_ms=20.0,  # ms
-        n_fft=None,  # fft points
-        max_freq=None,  # None for samplerate/2
-        specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
-        feat_dim=0,  # 'mfcc', 'fbank'
-        delat_delta=False,  # 'mfcc', 'fbank'
-        target_sample_rate=16000,  # target sample rate
-        use_dB_normalization=True,
-        target_dB=-20,
-        random_seed=0,
-        keep_transcription_text=False,
-        batch_size=32,  # batch size
-        num_workers=0,  # data loader workers
-        sortagrad=False,  # sorted in first epoch when True
-        shuffle_method="batch_shuffle",  # 'batch_shuffle', 'instance_shuffle'
-    ))
+_C.data = ManifestDataset.params()

-_C.model = CN(
-    dict(
-        num_conv_layers=2,  # Number of stacking convolution layers.
-        num_rnn_layers=3,  # Number of stacking RNN layers.
-        rnn_layer_size=1024,  # RNN layer size (number of RNN cells).
-        use_gru=True,  # Use gru if set True. Use simple rnn if set False.
-        share_rnn_weights=True  # Whether to share input-hidden weights between forward and backward directional RNNs. Notice that for GRU, weight sharing is not supported.
-    ))
-DeepSpeech2Model.params(_C.model)
+_C.collator = SpeechCollator.params()

-_C.training = CN(
-    dict(
-        lr=5e-4,  # learning rate
-        lr_decay=1.0,  # learning rate decay
-        weight_decay=1e-6,  # the coeff of weight decay
-        global_grad_clip=5.0,  # the global norm clip
-        n_epoch=50,  # train epochs
-    ))
+_C.model = DeepSpeech2Model.params()

-_C.decoding = CN(
-    dict(
-        alpha=2.5,  # Coef of LM for beam search.
-        beta=0.3,  # Coef of WC for beam search.
-        cutoff_prob=1.0,  # Cutoff probability for pruning.
-        cutoff_top_n=40,  # Cutoff number for pruning.
-        lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
-        decoding_method='ctc_beam_search',  # Decoding method. Options: ctc_beam_search, ctc_greedy
-        error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
-        num_proc_bsearch=8,  # # of CPUs for beam search.
-        beam_size=500,  # Beam search width.
-        batch_size=128,  # decoding batch size
-    ))
+_C.training = DeepSpeech2Trainer.params()
+
+_C.decoding = DeepSpeech2Tester.params()


 def get_cfg_defaults():
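With the inline dicts gone, config.py now just composes each component's own `params()` defaults. A minimal consumption sketch, not part of the diff, assuming the usual yacs pattern of `get_cfg_defaults()` returning a clone of `_C`:

    from deepspeech.exps.deepspeech2.config import get_cfg_defaults

    config = get_cfg_defaults()            # data/collator/model/training/decoding
    print(config.collator.specgram_type)   # 'linear', from SpeechCollator.params()
    print(config.data.max_input_len)       # 27.0, from ManifestDataset.params()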

deepspeech/exps/deepspeech2/model.py

@@ -15,11 +15,13 @@
 import time
 from collections import defaultdict
 from pathlib import Path
+from typing import Optional

 import numpy as np
 import paddle
 from paddle import distributed as dist
 from paddle.io import DataLoader
+from yacs.config import CfgNode

 from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
@@ -33,11 +35,26 @@ from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Log

 logger = Log(__name__).getlog()


 class DeepSpeech2Trainer(Trainer):
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        # training config
+        default = CfgNode(
+            dict(
+                lr=5e-4,  # learning rate
+                lr_decay=1.0,  # learning rate decay
+                weight_decay=1e-6,  # the coeff of weight decay
+                global_grad_clip=5.0,  # the global norm clip
+                n_epoch=50,  # train epochs
+            ))
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
     def __init__(self, config, args):
         super().__init__(config, args)
@@ -55,7 +72,7 @@ class DeepSpeech2Trainer(Trainer):
             'train_loss': float(loss),
         }
         msg += "train time: {:>.3f}s, ".format(iteration_time)
-        msg += "batch size: {}, ".format(self.config.data.batch_size)
+        msg += "batch size: {}, ".format(self.config.collator.batch_size)
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
         logger.info(msg)
@@ -102,8 +119,8 @@ class DeepSpeech2Trainer(Trainer):
     def setup_model(self):
         config = self.config
         model = DeepSpeech2Model(
-            feat_size=self.train_loader.dataset.feature_size,
-            dict_size=self.train_loader.dataset.vocab_size,
+            feat_size=self.train_loader.collate_fn.feature_size,
+            dict_size=self.train_loader.collate_fn.vocab_size,
             num_conv_layers=config.model.num_conv_layers,
             num_rnn_layers=config.model.num_rnn_layers,
             rnn_size=config.model.rnn_layer_size,
@@ -137,50 +154,73 @@ class DeepSpeech2Trainer(Trainer):
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
-        config.data.keep_transcription_text = False
+        config.collator.keep_transcription_text = False

         config.data.manifest = config.data.train_manifest
         train_dataset = ManifestDataset.from_config(config)

         config.data.manifest = config.data.dev_manifest
-        config.data.augmentation_config = ""
         dev_dataset = ManifestDataset.from_config(config)

         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,
-                batch_size=config.data.batch_size,
+                batch_size=config.collator.batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 drop_last=True,
-                sortagrad=config.data.sortagrad,
-                shuffle_method=config.data.shuffle_method)
+                sortagrad=config.collator.sortagrad,
+                shuffle_method=config.collator.shuffle_method)
         else:
             batch_sampler = SortagradBatchSampler(
                 train_dataset,
                 shuffle=True,
-                batch_size=config.data.batch_size,
+                batch_size=config.collator.batch_size,
                 drop_last=True,
-                sortagrad=config.data.sortagrad,
-                shuffle_method=config.data.shuffle_method)
+                sortagrad=config.collator.sortagrad,
+                shuffle_method=config.collator.shuffle_method)

-        collate_fn = SpeechCollator(keep_transcription_text=False)
+        collate_fn_train = SpeechCollator.from_config(config)
+
+        config.collator.augmentation_config = ""
+        collate_fn_dev = SpeechCollator.from_config(config)

         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
-            collate_fn=collate_fn,
-            num_workers=config.data.num_workers)
+            collate_fn=collate_fn_train,
+            num_workers=config.collator.num_workers)
         self.valid_loader = DataLoader(
             dev_dataset,
-            batch_size=config.data.batch_size,
+            batch_size=config.collator.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=collate_fn)
+            collate_fn=collate_fn_dev)
         logger.info("Setup train/valid Dataloader!")


 class DeepSpeech2Tester(DeepSpeech2Trainer):
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        # testing config
+        default = CfgNode(
+            dict(
+                alpha=2.5,  # Coef of LM for beam search.
+                beta=0.3,  # Coef of WC for beam search.
+                cutoff_prob=1.0,  # Cutoff probability for pruning.
+                cutoff_top_n=40,  # Cutoff number for pruning.
+                lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
+                decoding_method='ctc_beam_search',  # Decoding method. Options: ctc_beam_search, ctc_greedy
+                error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
+                num_proc_bsearch=8,  # # of CPUs for beam search.
+                beam_size=500,  # Beam search width.
+                batch_size=128,  # decoding batch size
+            ))
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
     def __init__(self, config, args):
         super().__init__(config, args)
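Because `SpeechCollator.from_config` snapshots `augmentation_config` at construction time, the ordering in `setup_dataloader` above is load-bearing: the train collator is built while augmentation is still configured, and the dev collator is built after it is blanked. A standalone sketch of the pattern, assuming `config` is already loaded and defrosted:

    config.collator.keep_transcription_text = False

    collate_fn_train = SpeechCollator.from_config(config)  # SpecAug etc. active

    config.collator.augmentation_config = ""  # "" maps to the empty pipeline '{}'
    collate_fn_dev = SpeechCollator.from_config(config)    # deterministic features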
@@ -193,13 +233,19 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             trans.append(''.join([chr(i) for i in ids]))
         return trans

-    def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None):
+    def compute_metrics(self,
+                        utts,
+                        audio,
+                        audio_len,
+                        texts,
+                        texts_len,
+                        fout=None):
         cfg = self.config.decoding
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
         error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer

-        vocab_list = self.test_loader.dataset.vocab_list
+        vocab_list = self.test_loader.collate_fn.vocab_list

         target_transcripts = self.ordid2token(texts, texts_len)
         result_transcripts = self.model.decode(
@@ -215,7 +261,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             cutoff_top_n=cfg.cutoff_top_n,
             num_processes=cfg.num_proc_bsearch)

-        for utt, target, result in zip(utts, target_transcripts, result_transcripts):
+        for utt, target, result in zip(utts, target_transcripts,
+                                       result_transcripts):
             errors, len_ref = errors_func(target, result)
             errors_sum += errors
             len_refs += len_ref
@@ -245,7 +292,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         with open(self.args.result_file, 'w') as fout:
             for i, batch in enumerate(self.test_loader):
                 utts, audio, audio_len, texts, texts_len = batch
-                metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout)
+                metrics = self.compute_metrics(utts, audio, audio_len, texts,
+                                               texts_len, fout)
                 errors_sum += metrics['errors_sum']
                 len_refs += metrics['len_refs']
                 num_ins += metrics['num_ins']
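A note on the accumulation above: summing `errors_sum` and `len_refs` across batches yields the corpus-level error rate, which weights utterances by reference length instead of averaging per-utterance rates. A toy illustration with hypothetical numbers:

    # Two references of 3 words and 1 word; the second hypothesis is fully wrong.
    errors = [0, 1]   # word errors per utterance
    lens = [3, 1]     # reference lengths in words

    mean_of_rates = sum(e / n for e, n in zip(errors, lens)) / 2  # (0 + 1.0) / 2 = 0.5
    corpus_rate = sum(errors) / sum(lens)                         # 1 / 4 = 0.25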
@@ -272,7 +320,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         infer_model = DeepSpeech2InferModel.from_pretrained(
             self.test_loader.dataset, self.config, self.args.checkpoint_path)
         infer_model.eval()
-        feat_dim = self.test_loader.dataset.feature_size
+        feat_dim = self.test_loader.collate_fn.feature_size
         static_model = paddle.jit.to_static(
             infer_model,
             input_spec=[
@@ -308,8 +356,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
     def setup_model(self):
         config = self.config
         model = DeepSpeech2Model(
-            feat_size=self.test_loader.dataset.feature_size,
-            dict_size=self.test_loader.dataset.vocab_size,
+            feat_size=self.test_loader.collate_fn.feature_size,
+            dict_size=self.test_loader.collate_fn.vocab_size,
             num_conv_layers=config.model.num_conv_layers,
             num_rnn_layers=config.model.num_rnn_layers,
             rnn_size=config.model.rnn_layer_size,
@@ -324,8 +372,6 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         # return raw text

         config.data.manifest = config.data.test_manifest
-        config.data.keep_transcription_text = True
-        config.data.augmentation_config = ""
         # filter test examples, will cause less examples, but no mismatch with training
         # and can use large batch size, save training time, so filter test egs now.
         # config.data.min_input_len = 0.0  # second
@@ -336,13 +382,15 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         # config.data.max_output_input_ratio = float('inf')

         test_dataset = ManifestDataset.from_config(config)

+        config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
         # return text ord id
         self.test_loader = DataLoader(
             test_dataset,
             batch_size=config.decoding.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=SpeechCollator(keep_transcription_text=True))
+            collate_fn=SpeechCollator.from_config(config))
         logger.info("Setup test Dataloader!")

     def setup_output_dir(self):

deepspeech/exps/u2/config.py

@@ -15,6 +15,7 @@ from yacs.config import CfgNode

 from deepspeech.exps.u2.model import U2Tester
 from deepspeech.exps.u2.model import U2Trainer
+from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.u2 import U2Model
@@ -22,6 +23,8 @@ _C = CfgNode()

 _C.data = ManifestDataset.params()

+_C.collator = SpeechCollator.params()
+
 _C.model = U2Model.params()

 _C.training = U2Trainer.params()

deepspeech/exps/u2/model.py

@@ -78,7 +78,8 @@ class U2Trainer(Trainer):
         start = time.time()
         utt, audio, audio_len, text, text_len = batch_data
-        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
+        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                    text_len)
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
         loss.backward()
@@ -100,7 +101,7 @@ class U2Trainer(Trainer):
             if (batch_index + 1) % train_conf.log_interval == 0:
                 msg += "train time: {:>.3f}s, ".format(iteration_time)
-                msg += "batch size: {}, ".format(self.config.data.batch_size)
+                msg += "batch size: {}, ".format(self.config.collator.batch_size)
                 msg += "accum: {}, ".format(train_conf.accum_grad)
                 msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in losses_np.items())
@@ -121,7 +122,8 @@ class U2Trainer(Trainer):
         total_loss = 0.0
         for i, batch in enumerate(self.valid_loader):
             utt, audio, audio_len, text, text_len = batch
-            loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
+            loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                        text_len)
             if paddle.isfinite(loss):
                 num_utts = batch[1].shape[0]
                 num_seen_utts += num_utts
@@ -211,51 +213,52 @@ class U2Trainer(Trainer):
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
-        config.data.keep_transcription_text = False
+        config.collator.keep_transcription_text = False

         # train/valid dataset, return token ids
         config.data.manifest = config.data.train_manifest
         train_dataset = ManifestDataset.from_config(config)

         config.data.manifest = config.data.dev_manifest
-        config.data.augmentation_config = ""
         dev_dataset = ManifestDataset.from_config(config)

-        collate_fn = SpeechCollator(keep_transcription_text=False)
+        collate_fn_train = SpeechCollator.from_config(config)
+
+        config.collator.augmentation_config = ""
+        collate_fn_dev = SpeechCollator.from_config(config)

         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,
-                batch_size=config.data.batch_size,
+                batch_size=config.collator.batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 drop_last=True,
-                sortagrad=config.data.sortagrad,
-                shuffle_method=config.data.shuffle_method)
+                sortagrad=config.collator.sortagrad,
+                shuffle_method=config.collator.shuffle_method)
         else:
             batch_sampler = SortagradBatchSampler(
                 train_dataset,
                 shuffle=True,
-                batch_size=config.data.batch_size,
+                batch_size=config.collator.batch_size,
                 drop_last=True,
-                sortagrad=config.data.sortagrad,
-                shuffle_method=config.data.shuffle_method)
+                sortagrad=config.collator.sortagrad,
+                shuffle_method=config.collator.shuffle_method)

         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
-            collate_fn=collate_fn,
-            num_workers=config.data.num_workers, )
+            collate_fn=collate_fn_train,
+            num_workers=config.collator.num_workers, )
         self.valid_loader = DataLoader(
             dev_dataset,
-            batch_size=config.data.batch_size,
+            batch_size=config.collator.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=collate_fn)
+            collate_fn=collate_fn_dev)

         # test dataset, return raw text
         config.data.manifest = config.data.test_manifest
-        config.data.keep_transcription_text = True
-        config.data.augmentation_config = ""
         # filter test examples, will cause less examples, but no mismatch with training
         # and can use large batch size, save training time, so filter test egs now.
         # config.data.min_input_len = 0.0  # second
@@ -264,22 +267,25 @@ class U2Trainer(Trainer):
         # config.data.max_output_len = float('inf')  # tokens
         # config.data.min_output_input_ratio = 0.00
         # config.data.max_output_input_ratio = float('inf')

         test_dataset = ManifestDataset.from_config(config)
         # return text ord id
+        config.collator.keep_transcription_text = True
+        config.collator.augmentation_config = ""
         self.test_loader = DataLoader(
             test_dataset,
             batch_size=config.decoding.batch_size,
             shuffle=False,
             drop_last=False,
-            collate_fn=SpeechCollator(keep_transcription_text=True))
+            collate_fn=SpeechCollator.from_config(config))
         logger.info("Setup train/valid/test Dataloader!")

     def setup_model(self):
         config = self.config
         model_conf = config.model
         model_conf.defrost()
-        model_conf.input_dim = self.train_loader.dataset.feature_size
-        model_conf.output_dim = self.train_loader.dataset.vocab_size
+        model_conf.input_dim = self.train_loader.collate_fn.feature_size
+        model_conf.output_dim = self.train_loader.collate_fn.vocab_size
         model_conf.freeze()
         model = U2Model.from_config(model_conf)
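The recurring substitution in this diff, `loader.dataset.X` becoming `loader.collate_fn.X`, reflects the new contract: the collator, not the dataset, owns the SpeechFeaturizer, so it is the source of truth for feature geometry and vocabulary. A sketch of the new accessors, assuming `train_loader` was built as in `setup_dataloader` above:

    feat_dim = train_loader.collate_fn.feature_size  # e.g. 161 bins for a linear spectrogram
    vocab_size = train_loader.collate_fn.vocab_size
    stride_ms = train_loader.collate_fn.stride_ms    # frame shift, 10.0 ms by default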
@@ -368,14 +374,20 @@ class U2Tester(U2Trainer):
             trans.append(''.join([chr(i) for i in ids]))
         return trans

-    def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None):
+    def compute_metrics(self,
+                        utts,
+                        audio,
+                        audio_len,
+                        texts,
+                        texts_len,
+                        fout=None):
         cfg = self.config.decoding
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
         error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer

         start_time = time.time()
-        text_feature = self.test_loader.dataset.text_feature
+        text_feature = self.test_loader.collate_fn.text_feature
         target_transcripts = self.ordid2token(texts, texts_len)
         result_transcripts = self.model.decode(
             audio,
@@ -395,7 +407,8 @@ class U2Tester(U2Trainer):
             simulate_streaming=cfg.simulate_streaming)
         decode_time = time.time() - start_time

-        for utt, target, result in zip(utts, target_transcripts, result_transcripts):
+        for utt, target, result in zip(utts, target_transcripts,
+                                       result_transcripts):
             errors, len_ref = errors_func(target, result)
             errors_sum += errors
             len_refs += len_ref
@@ -423,7 +436,7 @@ class U2Tester(U2Trainer):
         self.model.eval()
         logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")

-        stride_ms = self.test_loader.dataset.stride_ms
+        stride_ms = self.test_loader.collate_fn.stride_ms
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         num_frames = 0.0
@@ -496,7 +509,7 @@ class U2Tester(U2Trainer):
         infer_model = U2InferModel.from_pretrained(self.test_loader.dataset,
                                                    self.config.model.clone(),
                                                    self.args.checkpoint_path)
-        feat_dim = self.test_loader.dataset.feature_size
+        feat_dim = self.test_loader.collate_fn.feature_size
         input_spec = [
             paddle.static.InputSpec(
                 shape=[None, feat_dim, None],

deepspeech/frontend/featurizer/speech_featurizer.py

@@ -107,7 +107,6 @@ class SpeechFeaturizer(object):
     @property
     def vocab_size(self):
         """Return the vocabulary size.
-
         Returns:
             int: Vocabulary size.
         """
@@ -116,7 +115,6 @@ class SpeechFeaturizer(object):
     @property
     def vocab_list(self):
         """Return the vocabulary in list.
-
         Returns:
             List[str]:
         """
@@ -125,7 +123,6 @@ class SpeechFeaturizer(object):
     @property
     def vocab_dict(self):
         """Return the vocabulary in dict.
-
         Returns:
             Dict[str, int]:
         """
@@ -134,7 +131,6 @@ class SpeechFeaturizer(object):
     @property
     def feature_size(self):
         """Return the audio feature size.
-
         Returns:
             int: audio feature size.
         """
@@ -143,7 +139,6 @@ class SpeechFeaturizer(object):
     @property
     def stride_ms(self):
         """time length in `ms` unit per frame
-
         Returns:
             float: time(ms)/frame
         """
@@ -152,7 +147,6 @@ class SpeechFeaturizer(object):
     @property
     def text_feature(self):
         """Return the text feature object.
-
         Returns:
             TextFeaturizer: object.
         """

deepspeech/io/collator.py

@@ -11,8 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import io
+import tarfile
+from collections import namedtuple
+from typing import Optional

 import numpy as np
+from yacs.config import CfgNode

+from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
+from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
+from deepspeech.frontend.normalizer import FeatureNormalizer
+from deepspeech.frontend.speech import SpeechSegment
 from deepspeech.frontend.utility import IGNORE_ID
 from deepspeech.io.utility import pad_sequence
 from deepspeech.utils.log import Log
@@ -21,17 +30,220 @@ __all__ = ["SpeechCollator"]

 logger = Log(__name__).getlog()

+# namedtuple needs to be global for pickle.
+TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])


 class SpeechCollator():
-    def __init__(self, keep_transcription_text=True):
-        """
-        Padding audio features with zeros to make them have the same shape (or
-        a user-defined shape) within one bach.
-        """
+    @classmethod
+    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
+        default = CfgNode(
+            dict(
+                augmentation_config="",
+                random_seed=0,
+                mean_std_filepath="",
+                unit_type="char",
+                vocab_filepath="",
+                spm_model_prefix="",
+                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
+                feat_dim=0,  # 'mfcc', 'fbank'
+                delta_delta=False,  # 'mfcc', 'fbank'
+                stride_ms=10.0,  # ms
+                window_ms=20.0,  # ms
+                n_fft=None,  # fft points
+                max_freq=None,  # None for samplerate/2
+                target_sample_rate=16000,  # target sample rate
+                use_dB_normalization=True,
+                target_dB=-20,
+                dither=1.0,  # feature dither
+                keep_transcription_text=False))
+        if config is not None:
+            config.merge_from_other_cfg(default)
+        return default
+
+    @classmethod
+    def from_config(cls, config):
+        """Build a SpeechCollator object from a config.
+
+        Args:
+            config (yacs.config.CfgNode): configs object.
+
+        Returns:
+            SpeechCollator: collator object.
+        """
+        assert 'augmentation_config' in config.collator
+        assert 'keep_transcription_text' in config.collator
+        assert 'mean_std_filepath' in config.collator
+        assert 'vocab_filepath' in config.collator
+        assert 'specgram_type' in config.collator
+        assert 'n_fft' in config.collator
+        assert config.collator
+
+        if isinstance(config.collator.augmentation_config, (str, bytes)):
+            if config.collator.augmentation_config:
+                aug_file = io.open(
+                    config.collator.augmentation_config,
+                    mode='r',
+                    encoding='utf8')
+            else:
+                aug_file = io.StringIO(initial_value='{}', newline='')
+        else:
+            aug_file = config.collator.augmentation_config
+            assert isinstance(aug_file, io.StringIO)
+
+        speech_collator = cls(
+            aug_file=aug_file,
+            random_seed=0,
+            mean_std_filepath=config.collator.mean_std_filepath,
+            unit_type=config.collator.unit_type,
+            vocab_filepath=config.collator.vocab_filepath,
+            spm_model_prefix=config.collator.spm_model_prefix,
+            specgram_type=config.collator.specgram_type,
+            feat_dim=config.collator.feat_dim,
+            delta_delta=config.collator.delta_delta,
+            stride_ms=config.collator.stride_ms,
+            window_ms=config.collator.window_ms,
+            n_fft=config.collator.n_fft,
+            max_freq=config.collator.max_freq,
+            target_sample_rate=config.collator.target_sample_rate,
+            use_dB_normalization=config.collator.use_dB_normalization,
+            target_dB=config.collator.target_dB,
+            dither=config.collator.dither,
+            keep_transcription_text=config.collator.keep_transcription_text)
+        return speech_collator
+
+    def __init__(
+            self,
+            aug_file,
+            mean_std_filepath,
+            vocab_filepath,
+            spm_model_prefix,
+            random_seed=0,
+            unit_type="char",
+            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
+            feat_dim=0,  # 'mfcc', 'fbank'
+            delta_delta=False,  # 'mfcc', 'fbank'
+            stride_ms=10.0,  # ms
+            window_ms=20.0,  # ms
+            n_fft=None,  # fft points
+            max_freq=None,  # None for samplerate/2
+            target_sample_rate=16000,  # target sample rate
+            use_dB_normalization=True,
+            target_dB=-20,
+            dither=1.0,
+            keep_transcription_text=True):
+        """SpeechCollator Collator
+
+        Args:
+            unit_type (str): token unit type, e.g. char, word, spm
+            vocab_filepath (str): vocab file path.
+            mean_std_filepath (str): mean and std file path, whose suffix is *.npy
+            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
+            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
+            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
+            window_ms (float, optional): window size in ms. Defaults to 20.0.
+            n_fft (int, optional): fft points for rfft. Defaults to None.
+            max_freq (int, optional): max cut freq. Defaults to None.
+            target_sample_rate (int, optional): target sample rate used for training. Defaults to 16000.
+            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
+            feat_dim (int, optional): audio feature dim, used by 'mfcc' or 'fbank'. Defaults to None.
+            delta_delta (bool, optional): audio feature with delta-delta, used by 'fbank' or 'mfcc'. Defaults to False.
+            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
+            target_dB (int, optional): target dB. Defaults to -20.
+            random_seed (int, optional): for the random generator. Defaults to 0.
+            keep_transcription_text (bool, optional): True when not in training mode, to skip the tokenizer. Defaults to False.
+                If ``keep_transcription_text`` is False, text is token ids, else raw string.
+
+        Do augmentations.
+        Padding audio features with zeros to make them have the same shape (or
+        a user-defined shape) within one batch.
         """
         self._keep_transcription_text = keep_transcription_text
+
+        self._local_data = TarLocalData(tar2info={}, tar2object={})
+        self._augmentation_pipeline = AugmentationPipeline(
+            augmentation_config=aug_file.read(), random_seed=random_seed)
+
+        self._normalizer = FeatureNormalizer(
+            mean_std_filepath) if mean_std_filepath else None
+
+        self._stride_ms = stride_ms
+        self._target_sample_rate = target_sample_rate
+
+        self._speech_featurizer = SpeechFeaturizer(
+            unit_type=unit_type,
+            vocab_filepath=vocab_filepath,
+            spm_model_prefix=spm_model_prefix,
+            specgram_type=specgram_type,
+            feat_dim=feat_dim,
+            delta_delta=delta_delta,
+            stride_ms=stride_ms,
+            window_ms=window_ms,
+            n_fft=n_fft,
+            max_freq=max_freq,
+            target_sample_rate=target_sample_rate,
+            use_dB_normalization=use_dB_normalization,
+            target_dB=target_dB,
+            dither=dither)
+
+    def _parse_tar(self, file):
+        """Parse a tar file to get a tarfile object
+        and a map containing tarinfos.
+        """
+        result = {}
+        f = tarfile.open(file)
+        for tarinfo in f.getmembers():
+            result[tarinfo.name] = tarinfo
+        return f, result
+
+    def _subfile_from_tar(self, file):
+        """Get subfile object from tar.
+
+        It will return a subfile object from the tar file
+        and cache the tar file info for the next reading request.
+        """
+        tarpath, filename = file.split(':', 1)[1].split('#', 1)
+        if 'tar2info' not in self._local_data.__dict__:
+            self._local_data.tar2info = {}
+        if 'tar2object' not in self._local_data.__dict__:
+            self._local_data.tar2object = {}
+        if tarpath not in self._local_data.tar2info:
+            object, infos = self._parse_tar(tarpath)
+            self._local_data.tar2info[tarpath] = infos
+            self._local_data.tar2object[tarpath] = object
+        return self._local_data.tar2object[tarpath].extractfile(
+            self._local_data.tar2info[tarpath][filename])
+
+    def process_utterance(self, audio_file, transcript):
+        """Load, augment, featurize and normalize for speech data.
+
+        :param audio_file: Filepath or file object of audio file.
+        :type audio_file: str | file
+        :param transcript: Transcription text.
+        :type transcript: str
+        :return: Tuple of audio feature tensor and data of transcription part,
+                 where transcription part could be token ids or text.
+        :rtype: tuple of (2darray, list)
+        """
+        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
+            speech_segment = SpeechSegment.from_file(
+                self._subfile_from_tar(audio_file), transcript)
+        else:
+            speech_segment = SpeechSegment.from_file(audio_file, transcript)
+
+        # audio augment
+        self._augmentation_pipeline.transform_audio(speech_segment)
+
+        specgram, transcript_part = self._speech_featurizer.featurize(
+            speech_segment, self._keep_transcription_text)
+        if self._normalizer:
+            specgram = self._normalizer.apply(specgram)
+
+        # specgram augment
+        specgram = self._augmentation_pipeline.transform_feature(specgram)
+        return specgram, transcript_part

     def __call__(self, batch):
         """batch examples
@@ -53,6 +265,7 @@ class SpeechCollator():
         text_lens = []
         utts = []
         for utt, audio, text in batch:
+            audio, text = self.process_utterance(audio, text)
             #utt
             utts.append(utt)
             # audio
@@ -79,3 +292,31 @@ class SpeechCollator():
             texts, padding_value=IGNORE_ID).astype(np.int64)
         text_lens = np.array(text_lens).astype(np.int64)
         return utts, padded_audios, audio_lens, padded_texts, text_lens
+
+    @property
+    def manifest(self):
+        return self._manifest
+
+    @property
+    def vocab_size(self):
+        return self._speech_featurizer.vocab_size
+
+    @property
+    def vocab_list(self):
+        return self._speech_featurizer.vocab_list
+
+    @property
+    def vocab_dict(self):
+        return self._speech_featurizer.vocab_dict
+
+    @property
+    def text_feature(self):
+        return self._speech_featurizer.text_feature
+
+    @property
+    def feature_size(self):
+        return self._speech_featurizer.feature_size
+
+    @property
+    def stride_ms(self):
+        return self._speech_featurizer.stride_ms
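End to end, the new SpeechCollator takes over everything ManifestDataset used to do per sample: load (including `tar:` sub-files), audio augment, featurize, normalize, SpecAug, then pad. A minimal usage sketch, with hypothetical file paths:

    from paddle.io import DataLoader

    from deepspeech.exps.deepspeech2.config import get_cfg_defaults
    from deepspeech.io.collator import SpeechCollator
    from deepspeech.io.dataset import ManifestDataset

    config = get_cfg_defaults()
    config.merge_from_file("conf/deepspeech2.yaml")  # hypothetical config file
    config.data.manifest = "data/manifest.tiny"      # hypothetical manifest

    dataset = ManifestDataset.from_config(config)    # only filters/sorts entries now
    collator = SpeechCollator.from_config(config)    # does all per-sample work

    loader = DataLoader(dataset, batch_size=4, collate_fn=collator)
    utts, audios, audio_lens, texts, text_lens = next(iter(loader))
    # audios are zero-padded features; texts are IGNORE_ID-padded token ids, or
    # raw strings when config.collator.keep_transcription_text is True.

One caveat visible in the diff: the `manifest` property returns `self._manifest`, which `__init__` never assigns, so accessing it on a collator raises AttributeError. Note also that `import tarfile` (used by `_parse_tar`) was added alongside `import io` above, since the method cannot run without it.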

deepspeech/io/dataset.py

@@ -11,20 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
-import tarfile
-import time
-from collections import namedtuple
 from typing import Optional

-import numpy as np
 from paddle.io import Dataset
 from yacs.config import CfgNode

-from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
-from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
-from deepspeech.frontend.normalizer import FeatureNormalizer
-from deepspeech.frontend.speech import SpeechSegment
 from deepspeech.frontend.utility import read_manifest
 from deepspeech.utils.log import Log
@@ -34,49 +25,19 @@ __all__ = [

 logger = Log(__name__).getlog()

-# namedtuple needs to be global for pickle.
-TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])


 class ManifestDataset(Dataset):
     @classmethod
     def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
         default = CfgNode(
             dict(
-                train_manifest="",
-                dev_manifest="",
-                test_manifest="",
                 manifest="",
-                unit_type="char",
-                vocab_filepath="",
-                spm_model_prefix="",
-                mean_std_filepath="",
-                augmentation_config="",
                 max_input_len=27.0,
                 min_input_len=0.0,
                 max_output_len=float('inf'),
                 min_output_len=0.0,
                 max_output_input_ratio=float('inf'),
-                min_output_input_ratio=0.0,
-                stride_ms=10.0,  # ms
-                window_ms=20.0,  # ms
-                n_fft=None,  # fft points
-                max_freq=None,  # None for samplerate/2
-                raw_wav=True,  # use raw_wav or kaldi feature
-                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
-                feat_dim=0,  # 'mfcc', 'fbank'
-                delta_delta=False,  # 'mfcc', 'fbank'
-                dither=1.0,  # feature dither
-                target_sample_rate=16000,  # target sample rate
-                use_dB_normalization=True,
-                target_dB=-20,
-                random_seed=0,
-                keep_transcription_text=False,
-                batch_size=32,  # batch size
-                num_workers=0,  # data loader workers
-                sortagrad=False,  # sorted in first epoch when True
-                shuffle_method="batch_shuffle",  # 'batch_shuffle', 'instance_shuffle'
-            ))
+                min_output_input_ratio=0.0, ))

         if config is not None:
             config.merge_from_other_cfg(default)
@@ -94,128 +55,38 @@ class ManifestDataset(Dataset):
         """
         assert 'manifest' in config.data
         assert config.data.manifest
-        assert 'keep_transcription_text' in config.data
-
-        if isinstance(config.data.augmentation_config, (str, bytes)):
-            if config.data.augmentation_config:
-                aug_file = io.open(
-                    config.data.augmentation_config, mode='r', encoding='utf8')
-            else:
-                aug_file = io.StringIO(initial_value='{}', newline='')
-        else:
-            aug_file = config.data.augmentation_config
-            assert isinstance(aug_file, io.StringIO)

         dataset = cls(
             manifest_path=config.data.manifest,
-            unit_type=config.data.unit_type,
-            vocab_filepath=config.data.vocab_filepath,
-            mean_std_filepath=config.data.mean_std_filepath,
-            spm_model_prefix=config.data.spm_model_prefix,
-            augmentation_config=aug_file.read(),
             max_input_len=config.data.max_input_len,
             min_input_len=config.data.min_input_len,
             max_output_len=config.data.max_output_len,
             min_output_len=config.data.min_output_len,
             max_output_input_ratio=config.data.max_output_input_ratio,
-            min_output_input_ratio=config.data.min_output_input_ratio,
-            stride_ms=config.data.stride_ms,
-            window_ms=config.data.window_ms,
-            n_fft=config.data.n_fft,
-            max_freq=config.data.max_freq,
-            target_sample_rate=config.data.target_sample_rate,
-            specgram_type=config.data.specgram_type,
-            feat_dim=config.data.feat_dim,
-            delta_delta=config.data.delta_delta,
-            dither=config.data.dither,
-            use_dB_normalization=config.data.use_dB_normalization,
-            target_dB=config.data.target_dB,
-            random_seed=config.data.random_seed,
-            keep_transcription_text=config.data.keep_transcription_text)
+            min_output_input_ratio=config.data.min_output_input_ratio, )
         return dataset

     def __init__(self,
                  manifest_path,
-                 unit_type,
-                 vocab_filepath,
-                 mean_std_filepath,
-                 spm_model_prefix=None,
-                 augmentation_config='{}',
                  max_input_len=float('inf'),
                  min_input_len=0.0,
                  max_output_len=float('inf'),
                  min_output_len=0.0,
                  max_output_input_ratio=float('inf'),
-                 min_output_input_ratio=0.0,
-                 stride_ms=10.0,
-                 window_ms=20.0,
-                 n_fft=None,
-                 max_freq=None,
-                 target_sample_rate=16000,
-                 specgram_type='linear',
-                 feat_dim=None,
-                 delta_delta=False,
-                 dither=1.0,
-                 use_dB_normalization=True,
-                 target_dB=-20,
-                 random_seed=0,
-                 keep_transcription_text=False):
+                 min_output_input_ratio=0.0):
         """Manifest Dataset

         Args:
             manifest_path (str): manifest json file path
-            unit_type (str): token unit type, e.g. char, word, spm
-            vocab_filepath (str): vocab file path.
-            mean_std_filepath (str): mean and std file path, whose suffix is *.npy
-            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
-            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
             max_input_len ([type], optional): maximum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
             min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
             max_output_len (float, optional): maximum output seq length, in modeling units. Defaults to 500.0.
             min_output_len (float, optional): minimum output seq length, in modeling units. Defaults to 0.0.
             max_output_input_ratio (float, optional): maximum output seq length / input seq length ratio. Defaults to 10.0.
             min_output_input_ratio (float, optional): minimum output seq length / input seq length ratio. Defaults to 0.05.
-            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
-            window_ms (float, optional): window size in ms. Defaults to 20.0.
-            n_fft (int, optional): fft points for rfft. Defaults to None.
-            max_freq (int, optional): max cut freq. Defaults to None.
-            target_sample_rate (int, optional): target sample rate used for training. Defaults to 16000.
-            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
-            feat_dim (int, optional): audio feature dim, used by 'mfcc' or 'fbank'. Defaults to None.
-            delta_delta (bool, optional): audio feature with delta-delta, used by 'fbank' or 'mfcc'. Defaults to False.
-            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
-            target_dB (int, optional): target dB. Defaults to -20.
-            random_seed (int, optional): for the random generator. Defaults to 0.
-            keep_transcription_text (bool, optional): True when not in training mode, to skip the tokenizer. Defaults to False.
         """
         super().__init__()
-        self._stride_ms = stride_ms
-        self._target_sample_rate = target_sample_rate
-
-        self._normalizer = FeatureNormalizer(
-            mean_std_filepath) if mean_std_filepath else None
-        self._augmentation_pipeline = AugmentationPipeline(
-            augmentation_config=augmentation_config, random_seed=random_seed)
-        self._speech_featurizer = SpeechFeaturizer(
-            unit_type=unit_type,
-            vocab_filepath=vocab_filepath,
-            spm_model_prefix=spm_model_prefix,
-            specgram_type=specgram_type,
-            feat_dim=feat_dim,
-            delta_delta=delta_delta,
-            stride_ms=stride_ms,
-            window_ms=window_ms,
-            n_fft=n_fft,
-            max_freq=max_freq,
-            target_sample_rate=target_sample_rate,
-            use_dB_normalization=use_dB_normalization,
-            target_dB=target_dB,
-            dither=dither)
-        self._rng = np.random.RandomState(random_seed)
-        self._keep_transcription_text = keep_transcription_text
-        # for caching tar files info
-        self._local_data = TarLocalData(tar2info={}, tar2object={})

         # read manifest
         self._manifest = read_manifest(
@@ -228,125 +99,9 @@ class ManifestDataset(Dataset):
             min_output_input_ratio=min_output_input_ratio)
         self._manifest.sort(key=lambda x: x["feat_shape"][0])

-    @property
-    def manifest(self):
-        return self._manifest
-
-    @property
-    def vocab_size(self):
-        return self._speech_featurizer.vocab_size
-
-    @property
-    def vocab_list(self):
-        return self._speech_featurizer.vocab_list
-
-    @property
-    def vocab_dict(self):
-        return self._speech_featurizer.vocab_dict
-
-    @property
-    def text_feature(self):
-        return self._speech_featurizer.text_feature
-
-    @property
-    def feature_size(self):
-        return self._speech_featurizer.feature_size
-
-    @property
-    def stride_ms(self):
-        return self._speech_featurizer.stride_ms
-
-    def _parse_tar(self, file):
-        """Parse a tar file to get a tarfile object
-        and a map containing tarinfos.
-        """
-        result = {}
-        f = tarfile.open(file)
-        for tarinfo in f.getmembers():
-            result[tarinfo.name] = tarinfo
-        return f, result
-
-    def _subfile_from_tar(self, file):
-        """Get subfile object from tar.
-
-        It will return a subfile object from the tar file
-        and cache the tar file info for the next reading request.
-        """
-        tarpath, filename = file.split(':', 1)[1].split('#', 1)
-        if 'tar2info' not in self._local_data.__dict__:
-            self._local_data.tar2info = {}
-        if 'tar2object' not in self._local_data.__dict__:
-            self._local_data.tar2object = {}
-        if tarpath not in self._local_data.tar2info:
-            object, infos = self._parse_tar(tarpath)
-            self._local_data.tar2info[tarpath] = infos
-            self._local_data.tar2object[tarpath] = object
-        return self._local_data.tar2object[tarpath].extractfile(
-            self._local_data.tar2info[tarpath][filename])
-
-    def process_utterance(self, audio_file, transcript):
-        """Load, augment, featurize and normalize for speech data.
-
-        :param audio_file: Filepath or file object of audio file.
-        :type audio_file: str | file
-        :param transcript: Transcription text.
-        :type transcript: str
-        :return: Tuple of audio feature tensor and data of transcription part,
-                 where transcription part could be token ids or text.
-        :rtype: tuple of (2darray, list)
-        """
-        start_time = time.time()
-        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
-            speech_segment = SpeechSegment.from_file(
-                self._subfile_from_tar(audio_file), transcript)
-        else:
-            speech_segment = SpeechSegment.from_file(audio_file, transcript)
-        load_wav_time = time.time() - start_time
-        #logger.debug(f"load wav time: {load_wav_time}")
-
-        # audio augment
-        start_time = time.time()
-        self._augmentation_pipeline.transform_audio(speech_segment)
-        audio_aug_time = time.time() - start_time
-        #logger.debug(f"audio augmentation time: {audio_aug_time}")
-
-        start_time = time.time()
-        specgram, transcript_part = self._speech_featurizer.featurize(
-            speech_segment, self._keep_transcription_text)
-        if self._normalizer:
-            specgram = self._normalizer.apply(specgram)
-        feature_time = time.time() - start_time
-        #logger.debug(f"audio & text feature time: {feature_time}")
-
-        # specgram augment
-        start_time = time.time()
-        specgram = self._augmentation_pipeline.transform_feature(specgram)
-        feature_aug_time = time.time() - start_time
-        #logger.debug(f"audio feature augmentation time: {feature_aug_time}")
-        return specgram, transcript_part
-
-    def _instance_reader_creator(self, manifest):
-        """
-        Instance reader creator. Create a callable function to produce
-        instances of data.
-
-        Instance: a tuple of ndarray of audio spectrogram and a list of
-        token indices for transcript.
-        """
-
-        def reader():
-            for instance in manifest:
-                inst = self.process_utterance(instance["feat"],
-                                              instance["text"])
-                yield inst
-
-        return reader
-
     def __len__(self):
         return len(self._manifest)

     def __getitem__(self, idx):
         instance = self._manifest[idx]
-        feat, text = self.process_utterance(instance["feat"],
-                                            instance["text"])
-        return instance["utt"], feat, text
+        return instance["utt"], instance["feat"], instance["text"]
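With featurization gone, `__getitem__` is now trivial and I/O-free; the heavy lifting happens in the collator inside the DataLoader workers. A sketch of the new per-item contract, assuming `dataset` is a constructed ManifestDataset:

    utt, audio_file, text = dataset[0]
    # utt:        utterance id string from the manifest
    # audio_file: audio filepath, possibly a "tar:/path/archive.tar#inner.wav" entry
    # text:       raw transcript; tokenization now happens in SpeechCollator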

deepspeech/models/u2.py

@@ -905,7 +905,6 @@ class U2InferModel(U2Model):
     def __init__(self, configs: dict):
         super().__init__(configs)
-
     def forward(self,
                 feats,
                 feats_lengths,

@@ -5,29 +5,34 @@ data:
   test_manifest: data/manifest.test
   mean_std_filepath: data/mean_std.json
   vocab_filepath: data/vocab.txt
-  augmentation_config: conf/augmentation.json
-  batch_size: 64 # one gpu
   min_input_len: 0.0
   max_input_len: 27.0 # second
   min_output_len: 0.0
   max_output_len: .inf
   min_output_input_ratio: 0.00
   max_output_input_ratio: .inf
+
+collator:
+  augmentation_config: conf/augmentation.json
+  random_seed: 0
+  spm_model_prefix:
   specgram_type: linear
-  target_sample_rate: 16000
-  max_freq: None
-  n_fft: None
+  feat_dim:
+  delta_delta: False
   stride_ms: 10.0
   window_ms: 20.0
-  delta_delta: False
-  dither: 1.0
+  n_fft: None
+  max_freq: None
+  target_sample_rate: 16000
   use_dB_normalization: True
   target_dB: -20
-  random_seed: 0
+  dither: 1.0
   keep_transcription_text: False
   sortagrad: True
   shuffle_method: batch_shuffle
   num_workers: 0
+  batch_size: 64 # one gpu

 model:
   num_conv_layers: 2
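The same reshuffle applies to every example config below: feature, augmentation, and batching options move from `data:` into a new top-level `collator:` section, while `data:` keeps only manifests and length filters. A quick sanity-check sketch against the composed defaults (file name hypothetical):

    from deepspeech.exps.deepspeech2.config import get_cfg_defaults

    config = get_cfg_defaults()
    print(sorted(config.collator.keys()))  # feature/augmentation options
    print(sorted(config.data.keys()))      # manifests and length filters only
    config.merge_from_file("conf/deepspeech2.yaml")  # hypothetical path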

@@ -3,17 +3,20 @@ data:
   train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test
-  vocab_filepath: data/vocab.txt
-  unit_type: 'char'
-  spm_model_prefix: ''
-  augmentation_config: conf/augmentation.json
-  batch_size: 64
   min_input_len: 0.5
   max_input_len: 20.0 # second
   min_output_len: 0.0
   max_output_len: 400.0
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
+
+collator:
+  vocab_filepath: data/vocab.txt
+  unit_type: 'char'
+  spm_model_prefix: ''
+  augmentation_config: conf/augmentation.json
+  batch_size: 64
   raw_wav: True  # use raw_wav or kaldi feature
   specgram_type: fbank #linear, mfcc, fbank
   feat_dim: 80
@@ -32,7 +35,6 @@ data:
   shuffle_method: batch_shuffle
   num_workers: 2
-
 # network architecture
 model:
   cmvn_file: "data/mean_std.json"

@@ -3,31 +3,37 @@ data:
   train_manifest: data/manifest.tiny
   dev_manifest: data/manifest.tiny
   test_manifest: data/manifest.tiny
-  mean_std_filepath: data/mean_std.json
-  vocab_filepath: data/vocab.txt
-  augmentation_config: conf/augmentation.json
-  batch_size: 4
   min_input_len: 0.0
   max_input_len: 27.0
   min_output_len: 0.0
   max_output_len: 400.0
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
+
+collator:
+  mean_std_filepath: data/mean_std.json
+  unit_type: char
+  vocab_filepath: data/vocab.txt
+  augmentation_config: conf/augmentation.json
+  random_seed: 0
+  spm_model_prefix:
   specgram_type: linear
-  target_sample_rate: 16000
-  max_freq: None
-  n_fft: None
+  feat_dim:
+  delta_delta: False
   stride_ms: 10.0
   window_ms: 20.0
-  delta_delta: False
-  dither: 1.0
+  n_fft: None
+  max_freq: None
+  target_sample_rate: 16000
   use_dB_normalization: True
   target_dB: -20
-  random_seed: 0
+  dither: 1.0
   keep_transcription_text: False
   sortagrad: True
   shuffle_method: batch_shuffle
   num_workers: 0
+  batch_size: 4

 model:
   num_conv_layers: 2
@@ -37,7 +43,7 @@ model:
   share_rnn_weights: True

 training:
-  n_epoch: 20
+  n_epoch: 24
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06

@@ -3,35 +3,37 @@ data:
   train_manifest: data/manifest.tiny
   dev_manifest: data/manifest.tiny
   test_manifest: data/manifest.tiny
-  vocab_filepath: data/vocab.txt
-  unit_type: 'spm'
-  spm_model_prefix: 'data/bpe_unigram_200'
-  mean_std_filepath: ""
-  augmentation_config: conf/augmentation.json
-  batch_size: 4
   min_input_len: 0.5 # second
   max_input_len: 20.0 # second
   min_output_len: 0.0 # tokens
   max_output_len: 400.0 # tokens
   min_output_input_ratio: 0.05
   max_output_input_ratio: 10.0
-  raw_wav: True  # use raw_wav or kaldi feature
-  specgram_type: fbank #linear, mfcc, fbank

+collator:
+  vocab_filepath: data/vocab.txt
+  mean_std_filepath: ""
+  augmentation_config: conf/augmentation.json
+  random_seed: 0
+  unit_type: 'spm'
+  spm_model_prefix: 'data/bpe_unigram_200'
+  specgram_type: fbank
   feat_dim: 80
   delta_delta: False
-  dither: 1.0
-  target_sample_rate: 16000
-  max_freq: None
-  n_fft: None
   stride_ms: 10.0
-  window_ms: 25.0
+  window_ms: 20.0
+  n_fft: None
+  max_freq: None
+  target_sample_rate: 16000
   use_dB_normalization: True
   target_dB: -20
-  random_seed: 0
+  dither: 1.0
   keep_transcription_text: False
+  batch_size: 4
   sortagrad: True
   shuffle_method: batch_shuffle
-  num_workers: 2
+  num_workers: 0 #2
+  raw_wav: True  # use raw_wav or kaldi feature

 # network architecture
@@ -70,7 +72,7 @@ model:

 training:
-  n_epoch: 2
+  n_epoch: 21
   accum_grad: 1
   global_grad_clip: 5.0
   optim: adam
