From 46088c0a16aa1476c095b80fee551c7df4a8ce71 Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Mon, 26 Sep 2022 12:19:30 +0000
Subject: [PATCH] eliminate attn transpose

---
 paddlespeech/s2t/modules/attention.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index c02de15e8..67bb869ed 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -271,7 +271,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
                 and `head * d_k == size`
         """
         q, k, v = self.forward_qkv(query, key, value)
-        q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
+        # q = q.transpose([0, 2, 1, 3])  # (batch, time1, head, d_k)
 
         # when export onnx model, for 1st chunk, we feed
         #   cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
@@ -302,9 +302,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
 
         # (batch, head, time1, d_k)
-        q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
+        # q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
+        q_with_bias_u = q + self.pos_bias_u.unsqueeze(1)
         # (batch, head, time1, d_k)
-        q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
+        # q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
+        q_with_bias_v = q + self.pos_bias_v.unsqueeze(1)
 
         # compute attention score
         # first compute matrix a and matrix c
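
Note (reviewer sketch, not part of the patch): per the surrounding comments,
forward_qkv returns q as (batch, head, time1, d_k), and pos_bias_u / pos_bias_v
are (head, d_k) parameters. Unsqueezing the bias to (head, 1, d_k) lets it
broadcast over the time1 axis, so the old transpose-add-transpose round trip
and the new broadcast add yield the same tensor. A minimal equivalence check,
with shapes assumed from those comments:

    import paddle

    batch, head, time1, d_k = 2, 4, 8, 16
    q = paddle.randn([batch, head, time1, d_k])  # as returned by forward_qkv
    pos_bias = paddle.randn([head, d_k])         # stand-in for pos_bias_u / pos_bias_v

    # old path: transpose to (batch, time1, head, d_k), add bias, transpose back
    old = (q.transpose([0, 2, 1, 3]) + pos_bias).transpose([0, 2, 1, 3])

    # new path: unsqueeze bias to (head, 1, d_k) so it broadcasts over time1
    new = q + pos_bias.unsqueeze(1)

    assert paddle.allclose(old, new).item()

Skipping the two extra transposes keeps q in the layout the attention score
computation needs anyway, which also simplifies the exported ONNX graph.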