diff --git a/paddlespeech/s2t/modules/embedding.py b/paddlespeech/s2t/modules/embedding.py
index 54324c2f6..f41a7b5d4 100644
--- a/paddlespeech/s2t/modules/embedding.py
+++ b/paddlespeech/s2t/modules/embedding.py
@@ -110,12 +110,10 @@ class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
             paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
             paddle.Tensor: for compatibility to RelPositionalEncoding, (batch=1, time, ...)
         """
-        T = x.shape[1]
         assert offset + x.shape[
             1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
                 offset, x.shape[1], self.max_len)
-        #TODO(Hui Zhang): using T = paddle.shape(x)[1], __getitem__ not support Tensor
-        pos_emb = self.pe[:, offset:offset + T]
+        pos_emb = self.pe[:, offset:offset + x.shape[1]]
         x = x * self.xscale + pos_emb
         return self.dropout(x), self.dropout(pos_emb)
 
@@ -164,6 +162,5 @@ class RelPositionalEncoding(PositionalEncoding):
                 1] < self.max_len, "offset: {} + x.shape[1]: {} is larger than the max_len: {}".format(
                     offset, x.shape[1], self.max_len)
         x = x * self.xscale
-        #TODO(Hui Zhang): using paddle.shape(x)[1], __getitem__ not support Tensor
         pos_emb = self.pe[:, offset:offset + x.shape[1]]
         return self.dropout(x), self.dropout(pos_emb)
diff --git a/paddlespeech/s2t/modules/subsampling.py b/paddlespeech/s2t/modules/subsampling.py
index 88451ddd7..2775988a7 100644
--- a/paddlespeech/s2t/modules/subsampling.py
+++ b/paddlespeech/s2t/modules/subsampling.py
@@ -139,8 +139,8 @@ class Conv2dSubsampling4(Conv2dSubsampling):
         """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
-        b, c, t, f = paddle.shape(x)
-        x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
+        b, c, t, f = x.shape
+        x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f]))
         x, pos_emb = self.pos_enc(x, offset)
         return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2]
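
Not part of the patch: a minimal standalone sketch of the reshape pattern the subsampling hunk adopts, using a toy tensor shape chosen only for illustration. It shows that `paddle.reshape` can infer the time axis from `-1`, so the call does not need the traced value of `t` that the old code read back from the tensor's shape.

```python
import paddle

# toy input with the layout the Conv2dSubsampling4 forward works on:
# (batch, channels, time, feat) -- sizes here are arbitrary examples
x = paddle.randn([4, 8, 25, 10])
b, c, t, f = x.shape  # plain Python ints in dynamic-graph mode

# move time in front of channels, then merge channel and feature axes;
# -1 lets reshape infer the time dimension instead of passing t explicitly
y = x.transpose([0, 2, 1, 3]).reshape([b, -1, c * f])
print(y.shape)  # [4, 25, 80]
```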