From 1407014758f99826a2de6f3490d2111c9e6c6831 Mon Sep 17 00:00:00 2001
From: Wang Huan
Date: Mon, 18 Sep 2023 03:00:27 +0000
Subject: [PATCH] fix view use, because paddle supports view now

---
 paddlespeech/s2t/modules/attention.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index 10ab3eaea..fb727b133 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -79,9 +79,9 @@ class MultiHeadedAttention(nn.Layer):
         """
         n_batch = query.shape[0]

-        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = self.linear_q(query).reshape([n_batch, -1, self.h, self.d_k])
+        k = self.linear_k(key).reshape([n_batch, -1, self.h, self.d_k])
+        v = self.linear_v(value).reshape([n_batch, -1, self.h, self.d_k])

         q = q.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         k = k.transpose([0, 2, 1, 3])  # (batch, head, time2, d_k)
@@ -129,8 +129,8 @@ class MultiHeadedAttention(nn.Layer):
             p_attn = self.dropout(attn)

         x = paddle.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose([0, 2, 1, 3]).view(n_batch, -1, self.h *
-                                           self.d_k)  # (batch, time1, d_model)
+        x = x.transpose([0, 2, 1, 3]).reshape([n_batch, -1, self.h *
+                                               self.d_k])  # (batch, time1, d_model)

         return self.linear_out(x)  # (batch, time1, d_model)
@@ -280,8 +280,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
         x_padded = paddle.cat([zero_pad, x], dim=-1)

-        x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
-                                 x.shape[2])
+        x_padded = x_padded.view([x.shape[0], x.shape[1], x.shape[3] + 1,
+                                  x.shape[2]])
         x = x_padded[:, :, 1:].view_as(x)  # [B, H, T1, T1]

         if zero_triu:
@@ -349,7 +349,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             new_cache = paddle.concat((k, v), axis=-1)

         n_batch_pos = pos_emb.shape[0]
-        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k])
         p = p.transpose([0, 2, 1, 3])  # (batch, head, time1, d_k)
         # (batch, head, time1, d_k)
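
Note on the change, for readers comparing the minus and plus lines: the removed calls pass the target shape as positional arguments (the torch-style Tensor.view signature), while the added calls pass the shape as a single list, which is the form Paddle's reshape/view and transpose expect. Below is a minimal standalone sketch of the head-splitting pattern touched by the first hunk; it is not part of the patch, the sizes and variable names are illustrative only, and it assumes a Paddle 2.x install.

    import paddle

    # Illustrative sizes (not from the patch): batch 3, 5 time steps,
    # 2 heads, d_k = 4, so d_model = 8.
    n_batch, time1, h, d_k = 3, 5, 2, 4
    x = paddle.rand([n_batch, time1, h * d_k])  # (batch, time1, d_model)

    # Paddle's Tensor.reshape takes the target shape as one list/tuple ...
    q = x.reshape([n_batch, -1, h, d_k])        # (batch, time1, head, d_k)
    # ... and Tensor.transpose takes the full permutation as a list.
    q = q.transpose([0, 2, 1, 3])               # (batch, head, time1, d_k)

    print(q.shape)  # [3, 2, 5, 4]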