|
|
|
@ -61,12 +61,14 @@ class U2BaseModel(nn.Module):
|
|
|
|
|
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
|
|
|
|
|
# network architecture
|
|
|
|
|
default = CfgNode()
|
|
|
|
|
# allow adding new items on merge_with_file
|
|
|
|
|
default.set_new_allowed(True)
|
|
|
|
|
default.cmvn_file = ""
|
|
|
|
|
default.cmvn_file_type = "npz"
|
|
|
|
|
default.input_dim = 0
|
|
|
|
|
default.output_dim = 0
|
|
|
|
|
# encoder related
|
|
|
|
|
default.encoder = 'conformer'
|
|
|
|
|
default.encoder = 'transformer'
|
|
|
|
|
default.encoder_conf = CfgNode(
|
|
|
|
|
dict(
|
|
|
|
|
output_size=256, # dimension of attention
|
|
|
|
@ -78,11 +80,12 @@ class U2BaseModel(nn.Module):
|
|
|
|
|
attention_dropout_rate=0.0,
|
|
|
|
|
input_layer='conv2d', # encoder input type, you can chose conv2d, conv2d6 and conv2d8
|
|
|
|
|
normalize_before=True,
|
|
|
|
|
cnn_module_kernel=15,
|
|
|
|
|
use_cnn_module=True,
|
|
|
|
|
activation_type='swish',
|
|
|
|
|
pos_enc_layer_type='rel_pos',
|
|
|
|
|
selfattention_layer_type='rel_selfattn', ))
|
|
|
|
|
# use_cnn_module=True,
|
|
|
|
|
# cnn_module_kernel=15,
|
|
|
|
|
# activation_type='swish',
|
|
|
|
|
# pos_enc_layer_type='rel_pos',
|
|
|
|
|
# selfattention_layer_type='rel_selfattn',
|
|
|
|
|
))
|
|
|
|
|
# decoder related
|
|
|
|
|
default.decoder = 'transformer'
|
|
|
|
|
default.decoder_conf = CfgNode(
|
|
|
|
|