add more ctc conf

pull/820/head
Hui Zhang 3 years ago
parent 41ed7a184c
commit 890a28f9bf

@ -661,9 +661,7 @@ class U2BaseModel(nn.Layer):
xs, offset, required_cache_size, subsampling_cache, xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache) elayers_output_cache, conformer_cnn_cache)
# @jit.to_static([ # @jit.to_static
# paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
# ])
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log """ Export interface for c++ call, apply linear transform and log
softmax before ctc softmax before ctc
@ -830,6 +828,7 @@ class U2Model(U2BaseModel):
Returns: Returns:
int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
""" """
# cmvn
if configs['cmvn_file'] is not None: if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'], mean, istd = load_cmvn(configs['cmvn_file'],
configs['cmvn_file_type']) configs['cmvn_file_type'])
@ -839,11 +838,13 @@ class U2Model(U2BaseModel):
else: else:
global_cmvn = None global_cmvn = None
# input & output dim
input_dim = configs['input_dim'] input_dim = configs['input_dim']
vocab_size = configs['output_dim'] vocab_size = configs['output_dim']
assert input_dim != 0, input_dim assert input_dim != 0, input_dim
assert vocab_size != 0, vocab_size assert vocab_size != 0, vocab_size
# encoder
encoder_type = configs.get('encoder', 'transformer') encoder_type = configs.get('encoder', 'transformer')
logger.info(f"U2 Encoder type: {encoder_type}") logger.info(f"U2 Encoder type: {encoder_type}")
if encoder_type == 'transformer': if encoder_type == 'transformer':
@ -855,17 +856,21 @@ class U2Model(U2BaseModel):
else: else:
raise ValueError(f"not support encoder type:{encoder_type}") raise ValueError(f"not support encoder type:{encoder_type}")
# decoder
decoder = TransformerDecoder(vocab_size, decoder = TransformerDecoder(vocab_size,
encoder.output_size(), encoder.output_size(),
**configs['decoder_conf']) **configs['decoder_conf'])
# ctc decoder and ctc loss
model_conf = configs['model_conf']
ctc = CTCDecoder( ctc = CTCDecoder(
odim=vocab_size, odim=vocab_size,
enc_n_units=encoder.output_size(), enc_n_units=encoder.output_size(),
blank_id=0, blank_id=0,
dropout_rate=0.0, dropout_rate=model_conf['ctc_dropout_rate'],
reduction=True, # sum reduction=True, # sum
batch_average=True, # sum / batch_size batch_average=True, # sum / batch_size
grad_norm_type='instance') grad_norm_type=model_conf['ctc_grad_norm_type'])
return vocab_size, encoder, decoder, ctc return vocab_size, encoder, decoder, ctc

@ -413,26 +413,26 @@ class U2STBaseModel(nn.Layer):
best_hyps = best_hyps[:, 1:] best_hyps = best_hyps[:, 1:]
return best_hyps return best_hyps
@jit.to_static # @jit.to_static
def subsampling_rate(self) -> int: def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the """ Export interface for c++ call, return subsampling_rate of the
model model
""" """
return self.encoder.embed.subsampling_rate return self.encoder.embed.subsampling_rate
@jit.to_static # @jit.to_static
def right_context(self) -> int: def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model """ Export interface for c++ call, return right_context of the model
""" """
return self.encoder.embed.right_context return self.encoder.embed.right_context
@jit.to_static # @jit.to_static
def sos_symbol(self) -> int: def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model """ Export interface for c++ call, return sos symbol id of the model
""" """
return self.sos return self.sos
@jit.to_static # @jit.to_static
def eos_symbol(self) -> int: def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model """ Export interface for c++ call, return eos symbol id of the model
""" """
@ -468,7 +468,7 @@ class U2STBaseModel(nn.Layer):
xs, offset, required_cache_size, subsampling_cache, xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache) elayers_output_cache, conformer_cnn_cache)
@jit.to_static # @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log """ Export interface for c++ call, apply linear transform and log
softmax before ctc softmax before ctc
@ -643,14 +643,16 @@ class U2STModel(U2STBaseModel):
decoder = TransformerDecoder(vocab_size, decoder = TransformerDecoder(vocab_size,
encoder.output_size(), encoder.output_size(),
**configs['decoder_conf']) **configs['decoder_conf'])
# ctc decoder and ctc loss
model_conf = configs['model_conf']
ctc = CTCDecoder( ctc = CTCDecoder(
odim=vocab_size, odim=vocab_size,
enc_n_units=encoder.output_size(), enc_n_units=encoder.output_size(),
blank_id=0, blank_id=0,
dropout_rate=0.0, dropout_rate=model_conf['ctc_dropout_rate'],
reduction=True, # sum reduction=True, # sum
batch_average=True, # sum / batch_size batch_average=True, # sum / batch_size
grad_norm_type='instance') grad_norm_type=model_conf['ctc_grad_norm_type'])
return vocab_size, encoder, (st_decoder, decoder, ctc) return vocab_size, encoder, (st_decoder, decoder, ctc)
else: else:

@ -76,6 +76,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -71,6 +71,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -76,6 +76,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -69,6 +69,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -72,6 +72,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -76,6 +76,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -69,6 +69,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -72,6 +72,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -58,6 +58,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -68,6 +68,8 @@ model:
model_conf: model_conf:
asr_weight: 0.0 asr_weight: 0.0
ctc_weight: 0.0 ctc_weight: 0.0
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -68,6 +68,8 @@ model:
model_conf: model_conf:
asr_weight: 0.5 asr_weight: 0.5
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -66,6 +66,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -76,6 +76,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -69,6 +69,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -72,6 +72,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

@ -66,6 +66,8 @@ model:
# hybrid CTC/attention # hybrid CTC/attention
model_conf: model_conf:
ctc_weight: 0.3 ctc_weight: 0.3
ctc_dropoutrate: 0.0
ctc_grad_norm_type: instance
lsm_weight: 0.1 # label smoothing option lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false length_normalized_loss: false

Loading…
Cancel
Save