remove rel_shift, test=asr

Branch: pull/2755/head
Author: yeyupiaoling (3 years ago)
parent b297635a75
commit fe8bbcc226

@@ -21,7 +21,6 @@ encoder_conf:
     normalize_before: false
     activation_type: 'swish'
     pos_enc_layer_type: 'rel_pos'
-    do_rel_shift: false
     time_reduction_layer_type: 'stream'
     causal: true
     use_dynamic_chunk: true

@@ -21,7 +21,7 @@ encoder_conf:
     normalize_before: false
     activation_type: 'swish'
     pos_enc_layer_type: 'rel_pos'
-    time_reduction_layer_type: 'conv2d'
+    time_reduction_layer_type: 'conv1d'
 # decoder related
 decoder: transformer
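If you maintain local copies of these recipes, the deleted key has to be dropped before the encoder will construct. A minimal, hypothetical migration sketch in Python (the recipe path and the use of PyYAML are illustrative, not part of this commit):

import yaml

# Hypothetical path; point this at your local recipe.
path = 'conf/chunk_squeezeformer.yaml'
with open(path) as f:
    conf = yaml.safe_load(f)

# 'do_rel_shift' was removed by this commit; drop it if present.
conf['encoder_conf'].pop('do_rel_shift', None)

with open(path, 'w') as f:
    yaml.safe_dump(conf, f)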

@@ -204,7 +204,6 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                  n_head,
                  n_feat,
                  dropout_rate,
-                 do_rel_shift=False,
                  adaptive_scale=False,
                  init_weights=False):
         """Construct an RelPositionMultiHeadedAttention object.
@@ -229,7 +228,6 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         pos_bias_v = self.create_parameter(
             (self.h, self.d_k), default_initializer=I.XavierUniform())
         self.add_parameter('pos_bias_v', pos_bias_v)
-        self.do_rel_shift = do_rel_shift
         self.adaptive_scale = adaptive_scale
         if self.adaptive_scale:
             ada_scale = self.create_parameter(
@@ -369,8 +367,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         matrix_bd = paddle.matmul(q_with_bias_v, p, transpose_y=True)
         # Remove rel_shift since it is useless in speech recognition,
         # and it requires special attention for streaming.
-        if self.do_rel_shift:
-            matrix_bd = self.rel_shift(matrix_bd)
+        # matrix_bd = self.rel_shift(matrix_bd)
         scores = (matrix_ac + matrix_bd) / math.sqrt(
             self.d_k)  # (batch, head, time1, time2)
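For context, the rel_shift disabled above is the Transformer-XL relative-shift trick: matrix_bd is computed against relative position embeddings, and the shift realigns row i so that column j holds the logit for relative offset j - i. A generic reference sketch of the classic operation, assuming the (batch, head, time1, time2) layout noted above (the textbook version, not necessarily byte-for-byte the method removed here):

import paddle

def rel_shift(x):
    # x: (batch, head, time1, time2) logits against relative positions.
    b, h, t1, t2 = x.shape
    # Pad a zero column, fold the last two axes, then drop the first
    # row: this shifts row i left by i, lining columns up with j - i.
    zero_pad = paddle.zeros([b, h, t1, 1], dtype=x.dtype)
    x_padded = paddle.concat([zero_pad, x], axis=-1)   # (b, h, t1, t2 + 1)
    x_padded = paddle.reshape(x_padded, [b, h, t2 + 1, t1])
    return paddle.reshape(x_padded[:, :, 1:, :], [b, h, t1, t2])

The in-code comment gives the rationale for dropping it: the shift brings no gain for speech recognition, and the fold assumes the full time2 axis is available at once, which is awkward under chunk-based streaming.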

@@ -515,7 +515,6 @@ class SqueezeformerEncoder(nn.Layer):
                  input_dropout_rate: float=0.1,
                  pos_enc_layer_type: str="rel_pos",
                  time_reduction_layer_type: str="conv1d",
-                 do_rel_shift: bool=True,
                  feed_forward_dropout_rate: float=0.1,
                  attention_dropout_rate: float=0.1,
                  cnn_module_kernel: int=31,
@@ -549,8 +548,6 @@ class SqueezeformerEncoder(nn.Layer):
             input_dropout_rate (float): Dropout rate of input projection layer.
             pos_enc_layer_type (str): Self attention type.
             time_reduction_layer_type (str): Conv1d or Conv2d reduction layer.
-            do_rel_shift (bool): Whether to do relative shift
-                operation on rel-attention module.
             cnn_module_kernel (int): Kernel size of CNN module.
             activation_type (str): Encoder activation function type.
             cnn_module_kernel (int): Kernel size of convolution module.
@@ -590,7 +587,7 @@ class SqueezeformerEncoder(nn.Layer):
         else:
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
             encoder_selfattn_layer_args = (attention_heads, encoder_dim,
-                                           attention_dropout_rate, do_rel_shift,
+                                           attention_dropout_rate,
                                            adaptive_scale, init_weights)
         # feed-forward module definition
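After this hunk the attention layer takes a five-element argument tuple. A hypothetical caller-side sketch for code being rebased across this commit (the values are placeholders; only the parameter names come from the diff):

# Post-commit signature: (n_head, n_feat, dropout_rate,
#                         adaptive_scale, init_weights)
attn = RelPositionMultiHeadedAttention(
    n_head=4, n_feat=256, dropout_rate=0.1,
    adaptive_scale=False, init_weights=False)

# Passing the removed flag now fails fast with a TypeError:
# RelPositionMultiHeadedAttention(4, 256, 0.1, do_rel_shift=True)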
