|
|
|
@ -661,9 +661,7 @@ class U2BaseModel(nn.Layer):
|
|
|
|
|
xs, offset, required_cache_size, subsampling_cache,
|
|
|
|
|
elayers_output_cache, conformer_cnn_cache)
|
|
|
|
|
|
|
|
|
|
# @jit.to_static([
|
|
|
|
|
# paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
|
|
|
|
|
# ])
|
|
|
|
|
# @jit.to_static
|
|
|
|
|
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
|
|
|
|
|
""" Export interface for c++ call, apply linear transform and log
|
|
|
|
|
softmax before ctc
|
|
|
|
@ -830,6 +828,7 @@ class U2Model(U2BaseModel):
|
|
|
|
|
Returns:
|
|
|
|
|
int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
|
|
|
|
|
"""
|
|
|
|
|
# cmvn
|
|
|
|
|
if configs['cmvn_file'] is not None:
|
|
|
|
|
mean, istd = load_cmvn(configs['cmvn_file'],
|
|
|
|
|
configs['cmvn_file_type'])
|
|
|
|
@ -839,11 +838,13 @@ class U2Model(U2BaseModel):
|
|
|
|
|
else:
|
|
|
|
|
global_cmvn = None
|
|
|
|
|
|
|
|
|
|
# input & output dim
|
|
|
|
|
input_dim = configs['input_dim']
|
|
|
|
|
vocab_size = configs['output_dim']
|
|
|
|
|
assert input_dim != 0, input_dim
|
|
|
|
|
assert vocab_size != 0, vocab_size
|
|
|
|
|
|
|
|
|
|
# encoder
|
|
|
|
|
encoder_type = configs.get('encoder', 'transformer')
|
|
|
|
|
logger.info(f"U2 Encoder type: {encoder_type}")
|
|
|
|
|
if encoder_type == 'transformer':
|
|
|
|
@ -855,17 +856,21 @@ class U2Model(U2BaseModel):
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"not support encoder type:{encoder_type}")
|
|
|
|
|
|
|
|
|
|
# decoder
|
|
|
|
|
decoder = TransformerDecoder(vocab_size,
|
|
|
|
|
encoder.output_size(),
|
|
|
|
|
**configs['decoder_conf'])
|
|
|
|
|
|
|
|
|
|
# ctc decoder and ctc loss
|
|
|
|
|
model_conf = configs['model_conf']
|
|
|
|
|
ctc = CTCDecoder(
|
|
|
|
|
odim=vocab_size,
|
|
|
|
|
enc_n_units=encoder.output_size(),
|
|
|
|
|
blank_id=0,
|
|
|
|
|
dropout_rate=0.0,
|
|
|
|
|
dropout_rate=model_conf['ctc_dropout_rate'],
|
|
|
|
|
reduction=True, # sum
|
|
|
|
|
batch_average=True, # sum / batch_size
|
|
|
|
|
grad_norm_type='instance')
|
|
|
|
|
grad_norm_type=model_conf['ctc_grad_norm_type'])
|
|
|
|
|
|
|
|
|
|
return vocab_size, encoder, decoder, ctc
|
|
|
|
|
|
|
|
|
|