@@ -45,33 +45,6 @@ logger = Log(__name__).getlog()
 class U2STTrainer(Trainer):
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        # training config
-        default = CfgNode(
-            dict(
-                n_epoch=50,  # train epochs
-                log_interval=100,  # steps
-                accum_grad=1,  # accum grad by # steps
-                global_grad_clip=5.0,  # the global norm clip
-            ))
-        default.optim = 'adam'
-        default.optim_conf = CfgNode(
-            dict(
-                lr=5e-4,  # learning rate
-                weight_decay=1e-6,  # the coeff of weight decay
-            ))
-        default.scheduler = 'warmuplr'
-        default.scheduler_conf = CfgNode(
-            dict(
-                warmup_steps=25000,
-                lr_decay=1.0,  # learning rate decay
-            ))
-
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
     def __init__(self, config, args):
         super().__init__(config, args)
@@ -127,7 +100,7 @@ class U2STTrainer(Trainer):
         for k, v in losses_np.items():
             report(k, v)
-        report("batch_size", self.config.collator.batch_size)
+        report("batch_size", self.config.batch_size)
         report("accum", train_conf.accum_grad)
         report("step_cost", iteration_time)
@@ -236,7 +209,7 @@ class U2STTrainer(Trainer):
                                 msg += ","
                             msg = msg[:-1]  # remove the last ","
                             if (batch_index + 1
-                                ) % self.config.training.log_interval == 0:
+                                ) % self.config.log_interval == 0:
                                 logger.info(msg)
                 except Exception as e:
                     logger.error(e)
@@ -287,7 +260,7 @@ class U2STTrainer(Trainer):
                 batch_frames_in=0,
                 batch_frames_out=0,
                 batch_frames_inout=0,
-                preprocess_conf=config.augmentation_config,  # aug will be off when train_mode=False
+                preprocess_conf=config.preprocess_config,  # aug will be off when train_mode=False
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
                 load_aux_output=load_transcript,
@@ -308,7 +281,7 @@ class U2STTrainer(Trainer):
                 batch_frames_in=0,
                 batch_frames_out=0,
                 batch_frames_inout=0,
-                preprocess_conf=config.augmentation_config,  # aug will be off when train_mode=False
+                preprocess_conf=config.preprocess_config,  # aug will be off when train_mode=False
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
                 load_aux_output=load_transcript,
@@ -319,7 +292,7 @@ class U2STTrainer(Trainer):
             # test dataset, return raw text
             decode_batch_size = config.get('decode', dict()).get('decode_batch_size', 1)
             self.test_loader = BatchDataLoader(
-                json_file=config.data.test_manifest,
+                json_file=config.test_manifest,
                 train_mode=False,
                 sortagrad=False,
                 batch_size=decode_batch_size,
@@ -332,7 +305,7 @@ class U2STTrainer(Trainer):
                 batch_frames_in=0,
                 batch_frames_out=0,
                 batch_frames_inout=0,
-                preprocess_conf=config.augmentation_config,  # aug will be off when train_mode=False
+                preprocess_conf=config.preprocess_config,  # aug will be off when train_mode=False
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
                 num_encs=1,
@@ -379,7 +352,7 @@ class U2STTrainer(Trainer):
                 config,
                 parameters,
                 lr_scheduler=None, ):
-            train_config = config.training
+            train_config = config
             optim_type = train_config.optim
             optim_conf = train_config.optim_conf
             scheduler_type = train_config.scheduler
@@ -405,41 +378,12 @@ class U2STTrainer(Trainer):
 class U2STTester(U2STTrainer):
-    @classmethod
-    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
-        # decoding config
-        default = CfgNode(
-            dict(
-                alpha=2.5,  # Coef of LM for beam search.
-                beta=0.3,  # Coef of WC for beam search.
-                cutoff_prob=1.0,  # Cutoff probability for pruning.
-                cutoff_top_n=40,  # Cutoff number for pruning.
-                lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
-                decoding_method='attention',  # Decoding method. Options: 'attention', 'ctc_greedy_search',
-                # 'ctc_prefix_beam_search', 'attention_rescoring'
-                error_rate_type='bleu',  # Error rate type for evaluation. Options `bleu`, 'char_bleu'
-                num_proc_bsearch=8,  # # of CPUs for beam search.
-                beam_size=10,  # Beam search width.
-                batch_size=16,  # decoding batch size
-                ctc_weight=0.0,  # ctc weight for attention rescoring decode mode.
-                decoding_chunk_size=-1,  # decoding chunk size. Defaults to -1.
-                # <0: for decoding, use full chunk.
-                # >0: for decoding, use fixed chunk size as set.
-                # 0: used for training, it's prohibited here.
-                num_decoding_left_chunks=-1,  # number of left chunks for decoding. Defaults to -1.
-                simulate_streaming=False,  # simulate streaming inference. Defaults to False.
-            ))
-
-        if config is not None:
-            config.merge_from_other_cfg(default)
-        return default
-
     def __init__(self, config, args):
         super().__init__(config, args)
         self.text_feature = TextFeaturizer(
-            unit_type=self.config.collator.unit_type,
-            vocab_filepath=self.config.collator.vocab_filepath,
-            spm_model_prefix=self.config.collator.spm_model_prefix)
+            unit_type=self.config.unit_type,
+            vocab=self.config.vocab_filepath,
+            spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list

     def id2token(self, texts, texts_len, text_feature):
@@ -526,7 +470,7 @@ class U2STTester(U2STTrainer):
         decode_cfg = self.config.decode
         bleu_func = bleu_score.char_bleu if decode_cfg.error_rate_type == 'char-bleu' else bleu_score.bleu
-        stride_ms = self.config.collator.stride_ms
+        stride_ms = self.config.stride_ms
         hyps, refs = [], []
         len_refs, num_ins = 0, 0
         num_frames = 0.0
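
Note on the config layout these changes assume: keys that previously lived under the `training`, `collator`, and `data` sub-nodes (and the old `augmentation_config` entry, now `preprocess_config`) are read from the top level of the config object. The sketch below is illustrative only; it builds a yacs-style CfgNode by hand, with key names mirroring those touched in the diff, while the concrete values and file paths are placeholders rather than the project's defaults.

# Illustrative sketch (not part of the patch): flattened config assumed by the
# new config.<key> accesses above. Values and paths are placeholders.
from yacs.config import CfgNode

config = CfgNode(
    dict(
        # formerly under config.training
        n_epoch=50,
        log_interval=100,
        accum_grad=1,
        global_grad_clip=5.0,
        optim='adam',
        optim_conf=dict(lr=5e-4, weight_decay=1e-6),
        scheduler='warmuplr',
        scheduler_conf=dict(warmup_steps=25000, lr_decay=1.0),
        # formerly under config.collator (placeholder values)
        batch_size=16,
        unit_type='spm',
        vocab_filepath='data/vocab.txt',
        spm_model_prefix='data/bpe_unigram_8000',
        stride_ms=10.0,
        num_workers=2,
        # formerly config.data.test_manifest / config.augmentation_config
        test_manifest='data/manifest.test',
        preprocess_config='conf/preprocess.yaml', ))

# old nested access            -> new flattened access
# config.collator.batch_size   -> config.batch_size
# config.training.log_interval -> config.log_interval
# config.data.test_manifest    -> config.test_manifest
assert config.batch_size == 16 and config.optim_conf.lr == 5e-4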