@@ -44,27 +44,11 @@ logger = Log(__name__).getlog()


 class DeepSpeech2Trainer(Trainer):
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        # training config
-        default = CfgNode(
-            dict(
-                lr=5e-4,  # learning rate
-                lr_decay=1.0,  # learning rate decay
-                weight_decay=1e-6,  # the coeff of weight decay
-                global_grad_clip=5.0,  # the global norm clip
-                n_epoch=50,  # train epochs
-            ))
-
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
     def __init__(self, config, args):
         super().__init__(config, args)

     def train_batch(self, batch_index, batch_data, msg):
-        train_conf = self.config.training
+        train_conf = self.config
         start = time.time()

         # forward
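For readers unfamiliar with the removed params() hook: it merged hard-coded defaults into the experiment's yacs CfgNode at runtime, and after this change those keys must come from the YAML config instead. A minimal standalone sketch of that pattern, with values copied from the deleted block (the two-node setup below is illustrative, not repo code):

from yacs.config import CfgNode

# Defaults formerly declared in DeepSpeech2Trainer.params().
default = CfgNode(dict(lr=5e-4, lr_decay=1.0, n_epoch=50))

config = CfgNode(dict(lr=1e-3, lr_decay=1.0, n_epoch=20))
config.merge_from_other_cfg(default)  # values from `default` overwrite `config`
print(config.lr, config.n_epoch)      # 0.0005 50
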
@@ -98,7 +82,7 @@ class DeepSpeech2Trainer(Trainer):
            iteration_time = time.time() - start

            msg += "train time: {:>.3f}s, ".format(iteration_time)
-            msg += "batch size: {}, ".format(self.config.collator.batch_size)
+            msg += "batch size: {}, ".format(self.config.batch_size)
            msg += "accum: {}, ".format(train_conf.accum_grad)
            msg += ', '.join('{}: {:>.6f}'.format(k, v)
                             for k, v in losses_np.items())
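With the flattened namespace the log line reads self.config.batch_size directly; the surrounding assembly is plain str.format. A toy reproduction with invented values:

losses_np = {'train_loss': 12.345678}   # illustrative values
iteration_time, batch_size, accum_grad = 0.532, 64, 1

msg = "train time: {:>.3f}s, ".format(iteration_time)
msg += "batch size: {}, ".format(batch_size)
msg += "accum: {}, ".format(accum_grad)
msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses_np.items())
print(msg)  # train time: 0.532s, batch size: 64, accum: 1, train_loss: 12.345678
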
@@ -126,7 +110,7 @@ class DeepSpeech2Trainer(Trainer):
                total_loss += float(loss) * num_utts
                valid_losses['val_loss'].append(float(loss))

-            if (i + 1) % self.config.training.log_interval == 0:
+            if (i + 1) % self.config.log_interval == 0:
                valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
                valid_dump['val_history_loss'] = total_loss / num_seen_utts
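Note the two averages: val_loss is the mean of the buffered per-batch losses, while val_history_loss is utterance-weighted over everything seen so far. A standalone sketch of the bookkeeping (variable names mirror the hunk; the loop data is invented):

import numpy as np

valid_losses = {'val_loss': []}
total_loss, num_seen_utts = 0.0, 0
log_interval = 2  # stands in for config.log_interval

for i, (loss, num_utts) in enumerate([(10.0, 8), (12.0, 8), (11.0, 4)]):
    total_loss += float(loss) * num_utts
    num_seen_utts += num_utts
    valid_losses['val_loss'].append(float(loss))
    if (i + 1) % log_interval == 0:
        valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
        valid_dump['val_history_loss'] = total_loss / num_seen_utts
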
@@ -146,15 +130,15 @@ class DeepSpeech2Trainer(Trainer):
     def setup_model(self):
         config = self.config.clone()
         config.defrost()
-        config.model.feat_size = self.train_loader.collate_fn.feature_size
-        #config.model.dict_size = self.train_loader.collate_fn.vocab_size
-        config.model.dict_size = len(self.train_loader.collate_fn.vocab_list)
+        config.feat_size = self.train_loader.collate_fn.feature_size
+        #config.dict_size = self.train_loader.collate_fn.vocab_size
+        config.dict_size = len(self.train_loader.collate_fn.vocab_list)
         config.freeze()

         if self.args.model_type == 'offline':
-            model = DeepSpeech2Model.from_config(config.model)
+            model = DeepSpeech2Model.from_config(config)
         elif self.args.model_type == 'online':
-            model = DeepSpeech2ModelOnline.from_config(config.model)
+            model = DeepSpeech2ModelOnline.from_config(config)
         else:
             raise Exception("wrong model type")

         if self.parallel:
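The defrost()/freeze() pair around the assignments is the usual yacs idiom: a frozen CfgNode is immutable, so runtime-derived keys such as feat_size can only be written while the node is unlocked. A minimal sketch (the key values are invented):

from yacs.config import CfgNode

config = CfgNode(dict(feat_size=-1, dict_size=-1))
config.freeze()

config.defrost()        # unlock the frozen node
config.feat_size = 161  # e.g. the collator's feature_size
config.dict_size = 29   # e.g. len(vocab_list)
config.freeze()         # re-freeze so later code cannot mutate it
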
@@ -163,17 +147,13 @@ class DeepSpeech2Trainer(Trainer):
         logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)

-        grad_clip = ClipGradByGlobalNormWithLog(
-            config.training.global_grad_clip)
+        grad_clip = ClipGradByGlobalNormWithLog(config.global_grad_clip)
         lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
-            learning_rate=config.training.lr,
-            gamma=config.training.lr_decay,
-            verbose=True)
+            learning_rate=config.lr, gamma=config.lr_decay, verbose=True)
         optimizer = paddle.optimizer.Adam(
             learning_rate=lr_scheduler,
             parameters=model.parameters(),
-            weight_decay=paddle.regularizer.L2Decay(
-                config.training.weight_decay),
+            weight_decay=paddle.regularizer.L2Decay(config.weight_decay),
             grad_clip=grad_clip)

         self.model = model
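The same wiring runs standalone against a toy model. ClipGradByGlobalNormWithLog is project-internal, so Paddle's stock ClipGradByGlobalNorm stands in below; the literals correspond to the flat keys (config.global_grad_clip, config.lr, config.lr_decay, config.weight_decay):

import paddle

model = paddle.nn.Linear(161, 29)  # stand-in for the DeepSpeech2 model
grad_clip = paddle.nn.ClipGradByGlobalNorm(5.0)
lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
    learning_rate=5e-4, gamma=1.0, verbose=True)  # gamma=1.0: constant lr
optimizer = paddle.optimizer.Adam(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=paddle.regularizer.L2Decay(1e-6),
    grad_clip=grad_clip)
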
@@ -184,59 +164,59 @@ class DeepSpeech2Trainer(Trainer):
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
-        config.collator.keep_transcription_text = False
+        config.keep_transcription_text = False

-        config.data.manifest = config.data.train_manifest
+        config.manifest = config.train_manifest
         train_dataset = ManifestDataset.from_config(config)

-        config.data.manifest = config.data.dev_manifest
+        config.manifest = config.dev_manifest
         dev_dataset = ManifestDataset.from_config(config)

-        config.data.manifest = config.data.test_manifest
+        config.manifest = config.test_manifest
         test_dataset = ManifestDataset.from_config(config)

         if self.parallel:
             batch_sampler = SortagradDistributedBatchSampler(
                 train_dataset,
-                batch_size=config.collator.batch_size,
+                batch_size=config.batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
+                sortagrad=config.sortagrad,
+                shuffle_method=config.shuffle_method)
         else:
             batch_sampler = SortagradBatchSampler(
                 train_dataset,
                 shuffle=True,
-                batch_size=config.collator.batch_size,
+                batch_size=config.batch_size,
                 drop_last=True,
-                sortagrad=config.collator.sortagrad,
-                shuffle_method=config.collator.shuffle_method)
+                sortagrad=config.sortagrad,
+                shuffle_method=config.shuffle_method)

         collate_fn_train = SpeechCollator.from_config(config)

-        config.collator.augmentation_config = ""
+        config.augmentation_config = ""
         collate_fn_dev = SpeechCollator.from_config(config)

-        config.collator.keep_transcription_text = True
-        config.collator.augmentation_config = ""
+        config.keep_transcription_text = True
+        config.augmentation_config = ""
         collate_fn_test = SpeechCollator.from_config(config)

         self.train_loader = DataLoader(
             train_dataset,
             batch_sampler=batch_sampler,
             collate_fn=collate_fn_train,
-            num_workers=config.collator.num_workers)
+            num_workers=config.num_workers)
         self.valid_loader = DataLoader(
             dev_dataset,
-            batch_size=config.collator.batch_size,
+            batch_size=config.batch_size,
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn_dev)
         self.test_loader = DataLoader(
             test_dataset,
-            batch_size=config.decoding.batch_size,
+            batch_size=config.decode.decode_batch_size,
             shuffle=False,
             drop_last=False,
             collate_fn=collate_fn_test)
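Each dataset comes from cloning the config, repointing config.manifest at a different manifest file, and calling from_config. A condensed sketch of that pattern as a hypothetical helper (build_dataset is not in the codebase, and the ManifestDataset import path is assumed):

from deepspeech.io.dataset import ManifestDataset  # import path assumed

def build_dataset(base_config, manifest_key):
    # Hypothetical helper mirroring the clone-and-mutate pattern above.
    config = base_config.clone()
    config.defrost()
    config.manifest = config[manifest_key]  # e.g. config.train_manifest
    return ManifestDataset.from_config(config)

# Given a loaded experiment config:
# train_dataset = build_dataset(config, 'train_manifest')
# dev_dataset = build_dataset(config, 'dev_manifest')
# test_dataset = build_dataset(config, 'test_manifest')
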
@@ -250,31 +230,10 @@ class DeepSpeech2Trainer(Trainer):

 class DeepSpeech2Tester(DeepSpeech2Trainer):
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        # testing config
-        default = CfgNode(
-            dict(
-                alpha=2.5,  # Coef of LM for beam search.
-                beta=0.3,  # Coef of WC for beam search.
-                cutoff_prob=1.0,  # Cutoff probability for pruning.
-                cutoff_top_n=40,  # Cutoff number for pruning.
-                lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
-                decoding_method='ctc_beam_search',  # Decoding method. Options: ctc_beam_search, ctc_greedy
-                error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
-                num_proc_bsearch=8,  # # of CPUs for beam search.
-                beam_size=500,  # Beam search width.
-                batch_size=128,  # decoding batch size
-            ))
-
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
     def __init__(self, config, args):
         self._text_featurizer = TextFeaturizer(
-            unit_type=config.collator.unit_type, vocab_filepath=None)
+            unit_type=config.unit_type, vocab=None)
         super().__init__(config, args)

     def ordid2token(self, texts, texts_len):
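The deleted decoding defaults move out of code and into the experiment config. For reference, the same keys expressed as the flat node the tester now expects (values copied from the removed block; decode_batch_size follows the rename visible in the dataloader hunk above):

from yacs.config import CfgNode

decode_defaults = CfgNode(
    dict(
        alpha=2.5,            # LM coefficient for beam search
        beta=0.3,             # word-count coefficient for beam search
        cutoff_prob=1.0,      # pruning probability cutoff
        cutoff_top_n=40,      # pruning candidate cutoff
        decoding_method='ctc_beam_search',
        error_rate_type='wer',
        num_proc_bsearch=8,   # CPUs used for beam search
        beam_size=500,
        decode_batch_size=128,
    ))
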
@@ -293,7 +252,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
                         texts,
                         texts_len,
                         fout=None):
-        cfg = self.config.decoding
+        cfg = self.config.decode
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
         error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
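The errors_func/error_rate_func pair is plain conditional dispatch on cfg.error_rate_type (error_rate is presumably the repo's error-rate utility module, which computes true edit distance). A self-contained illustration with crude stand-in metrics:

def char_errors(ref, hyp):
    # Crude stand-in (not true edit distance) for the repo's char_errors.
    return sum(a != b for a, b in zip(ref, hyp)), len(ref)

def word_errors(ref, hyp):
    # Stand-in for the repo's word_errors.
    ref_w, hyp_w = ref.split(), hyp.split()
    return sum(a != b for a, b in zip(ref_w, hyp_w)), len(ref_w)

error_rate_type = 'cer'  # stands in for cfg.error_rate_type
errors_func = char_errors if error_rate_type == 'cer' else word_errors
errors, len_ref = errors_func('hello world', 'hella world')
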
@@ -399,31 +358,3 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             self.export()
         except KeyboardInterrupt:
             exit(-1)
-
-    def setup(self):
-        """Setup the experiment.
-        """
-        paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
-
-        self.setup_output_dir()
-        self.setup_checkpointer()
-        self.setup_dataloader()
-        self.setup_model()
-
-        self.iteration = 0
-        self.epoch = 0
-
-    def setup_output_dir(self):
-        """Create a directory used for output.
-        """
-        # output dir
-        if self.args.output:
-            output_dir = Path(self.args.output).expanduser()
-            output_dir.mkdir(parents=True, exist_ok=True)
-        else:
-            output_dir = Path(
-                self.args.checkpoint_path).expanduser().parent.parent
-            output_dir.mkdir(parents=True, exist_ok=True)
-
-        self.output_dir = output_dir
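The removed setup()/setup_output_dir() overrides presumably now come from the base Trainer. For reference, the output-directory fallback they implemented, restated as a standalone function (resolve_output_dir is a hypothetical name):

from pathlib import Path

def resolve_output_dir(output: str, checkpoint_path: str) -> Path:
    # Mirrors the removed logic: an explicit --output wins; otherwise fall
    # back to the checkpoint's grandparent directory.
    if output:
        output_dir = Path(output).expanduser()
    else:
        output_dir = Path(checkpoint_path).expanduser().parent.parent
    output_dir.mkdir(parents=True, exist_ok=True)
    return output_dir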