Merge pull request #1099 from zh794390558/ctc

[asr] remove ctc grad norm type in config
pull/1103/head
Hui Zhang 3 years ago committed by GitHub
commit 4aa6a76036
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -43,7 +43,7 @@ model:
fc_layers_size_list: -1,
use_gru: False
blank_id: 0
ctc_grad_norm_type: null
training:
n_epoch: 65

@ -41,7 +41,7 @@ model:
use_gru: False
share_rnn_weights: True
blank_id: 0
ctc_grad_norm_type: null
training:
n_epoch: 50

@ -43,7 +43,7 @@ model:
fc_layers_size_list: 512, 256
use_gru: False
blank_id: 0
ctc_grad_norm_type: null
training:
n_epoch: 50

@ -76,8 +76,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -69,8 +69,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -72,8 +72,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -29,8 +29,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
@ -81,7 +79,7 @@ training:
optim_conf:
lr: 0.004
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0

@ -30,8 +30,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -68,8 +68,6 @@ model:
model_conf:
asr_weight: 0.0
ctc_weight: 0.0
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -68,8 +68,6 @@ model:
model_conf:
asr_weight: 0.5
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -68,8 +68,6 @@ model:
model_conf:
asr_weight: 0.0
ctc_weight: 0.0
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -68,8 +68,6 @@ model:
model_conf:
asr_weight: 0.5
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -66,8 +66,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.5
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -42,7 +42,7 @@ model:
use_gru: False
share_rnn_weights: True
blank_id: 0
ctc_grad_norm_type: null
training:
n_epoch: 5

@ -44,7 +44,7 @@ model:
fc_layers_size_list: 512, 256
use_gru: True
blank_id: 0
ctc_grad_norm_type: null
training:
n_epoch: 5

@ -76,8 +76,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -69,8 +69,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -72,8 +72,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -66,8 +66,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -33,8 +33,6 @@ model:
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false

@ -129,7 +129,7 @@ class DeepSpeech2Model(nn.Layer):
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
use_gru=True, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=True, #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
ctc_grad_norm_type='instance', ))
ctc_grad_norm_type=None,))
if config is not None:
config.merge_from_other_cfg(default)
return default
@ -143,7 +143,7 @@ class DeepSpeech2Model(nn.Layer):
use_gru=False,
share_rnn_weights=True,
blank_id=0,
ctc_grad_norm_type='instance'):
ctc_grad_norm_type=None):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
@ -220,16 +220,14 @@ class DeepSpeech2Model(nn.Layer):
"""
model = cls(
feat_size=dataloader.collate_fn.feature_size,
#feat_size=dataloader.dataset.feature_size,
dict_size=dataloader.collate_fn.vocab_size,
#dict_size=dataloader.dataset.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights,
blank_id=config.model.blank_id,
ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), )
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
@ -257,7 +255,7 @@ class DeepSpeech2Model(nn.Layer):
use_gru=config.use_gru,
share_rnn_weights=config.share_rnn_weights,
blank_id=config.blank_id,
ctc_grad_norm_type=config.ctc_grad_norm_type, )
ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), )
return model

@ -255,7 +255,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
fc_layers_size_list=[512, 256],
use_gru=True, #Use gru if set True. Use simple rnn if set False.
blank_id=0, # index of blank in vocob.txt
ctc_grad_norm_type='instance', ))
ctc_grad_norm_type=None, ))
if config is not None:
config.merge_from_other_cfg(default)
return default
@ -272,7 +272,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
fc_layers_size_list=[512, 256],
use_gru=False,
blank_id=0,
ctc_grad_norm_type='instance', ):
ctc_grad_norm_type=None, ):
super().__init__()
self.encoder = CRNNEncoder(
feat_size=feat_size,
@ -361,7 +361,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
fc_layers_size_list=config.model.fc_layers_size_list,
use_gru=config.model.use_gru,
blank_id=config.model.blank_id,
ctc_grad_norm_type=config.model.ctc_grad_norm_type, )
ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), )
infos = Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
@ -391,7 +391,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
fc_layers_size_list=config.fc_layers_size_list,
use_gru=config.use_gru,
blank_id=config.blank_id,
ctc_grad_norm_type=config.ctc_grad_norm_type, )
ctc_grad_norm_type=config.get('ctc_grad_norm_type', None), )
return model

@ -894,14 +894,16 @@ class U2Model(U2DecodeModel):
# ctc decoder and ctc loss
model_conf = configs['model_conf']
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
ctc = CTCDecoder(
odim=vocab_size,
enc_n_units=encoder.output_size(),
blank_id=0,
dropout_rate=model_conf['ctc_dropoutrate'],
dropout_rate=dropout_rate,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=model_conf['ctc_grad_norm_type'])
grad_norm_type=grad_norm_type)
return vocab_size, encoder, decoder, ctc

@ -655,14 +655,16 @@ class U2STModel(U2STBaseModel):
**configs['decoder_conf'])
# ctc decoder and ctc loss
model_conf = configs['model_conf']
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
ctc = CTCDecoder(
odim=vocab_size,
enc_n_units=encoder.output_size(),
blank_id=0,
dropout_rate=model_conf['ctc_dropoutrate'],
dropout_rate=dropout_rate,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=model_conf['ctc_grad_norm_type'])
grad_norm_type=grad_norm_type)
return vocab_size, encoder, (st_decoder, decoder, ctc)
else:

@ -74,8 +74,6 @@ class TestU2Model(unittest.TestCase):
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
length_normalized_loss: false
"""
cfg = CN().load_cfg(conf_str)
@ -128,8 +126,6 @@ class TestU2Model(unittest.TestCase):
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
ctc_dropoutrate: 0.0
ctc_grad_norm_type: null
length_normalized_loss: false
"""
cfg = CN().load_cfg(conf_str)

Loading…
Cancel
Save