@@ -26,10 +26,7 @@ from paddlespeech.s2t.utils.log import Log
 
 logger = Log(__name__).getlog()
 
-__all__ = [
-    "MultiHeadedAttention", "RelPositionMultiHeadedAttention",
-    "RelPositionMultiHeadedAttention2"
-]
+__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
 
 # Relative Positional Encodings
 # https://www.jianshu.com/p/c0608efcc26f
@@ -203,7 +200,10 @@ class MultiHeadedAttention(nn.Layer):
 
 class RelPositionMultiHeadedAttention(MultiHeadedAttention):
     """Multi-Head Attention layer with relative position encoding."""
 
-    def __init__(self, n_head, n_feat, dropout_rate,
+    def __init__(self,
+                 n_head,
+                 n_feat,
+                 dropout_rate,
+                 do_rel_shift=False,
+                 adaptive_scale=False,
+                 init_weights=False):
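
For context, a minimal usage sketch (not part of the patch): it shows how the consolidated class might be constructed once this change is applied. The module path, argument values, and comment wording are assumptions inferred from the surrounding hunks, not taken from the diff itself.

    # Hypothetical usage sketch; values are illustrative, not from the patch.
    from paddlespeech.s2t.modules.attention import RelPositionMultiHeadedAttention

    # The three new keyword arguments default to False, so existing call sites
    # that pass only (n_head, n_feat, dropout_rate) keep their old behaviour.
    attn = RelPositionMultiHeadedAttention(
        n_head=4,              # number of attention heads
        n_feat=256,            # model width; assumed divisible by n_head
        dropout_rate=0.1,
        do_rel_shift=False,    # presumably: apply the relative-shift trick to position scores
        adaptive_scale=False,  # presumably: learnable per-channel scale/bias on the input
        init_weights=False)    # presumably: re-initialize projection weights at construction

Since `__all__` no longer exports `RelPositionMultiHeadedAttention2`, code that used that class would presumably migrate to `RelPositionMultiHeadedAttention` with the new flags enabled.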