You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/paddlespeech/t2s/modules/transformer/subsampling.py

80 lines
2.7 KiB

3 years ago
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""
import paddle
from paddle import nn
3 years ago
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
class Conv2dSubsampling(nn.Layer):
    """Convolutional 2D subsampling (to roughly 1/4 length).

    Two stacked ``Conv2D`` layers (kernel size 3, stride 2), each followed by
    ``ReLU``, shrink both the time and the feature axis. The result is
    flattened per frame, projected to ``odim`` by a linear layer and passed
    through a positional encoding layer.

    Args:
        idim (int):
            Input dimension.
        odim (int):
            Output dimension.
        dropout_rate (float):
            Dropout rate. Only used to build the default
            ``PositionalEncoding`` when ``pos_enc`` is None.
        pos_enc (nn.Layer):
            Custom position encoding layer. When None,
            ``PositionalEncoding(odim, dropout_rate)`` is used.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling object."""
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2D(1, odim, 3, 2),
            nn.ReLU(),
            nn.Conv2D(odim, odim, 3, 2),
            nn.ReLU(), )
        # Each kernel-3 / stride-2 convolution maps a length n to (n - 1) // 2,
        # so the feature axis ends up with ((idim - 1) // 2 - 1) // 2 bins per
        # channel; the linear layer consumes all odim channels flattened.
        self.out = nn.Sequential(
            nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else
            PositionalEncoding(odim, dropout_rate), )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (Tensor):
                Input tensor (#batch, time, idim).
            x_mask (Tensor):
                Input mask (#batch, 1, time). May be None.

        Returns:
            Tensor:
                Subsampled tensor (#batch, time', odim), where
                time' = ((time - 1) // 2 - 1) // 2 (roughly time // 4).
            Tensor:
                Subsampled mask (#batch, 1, time') with the same time',
                or None when ``x_mask`` is None.

        """
        # Add a singleton channel axis: (b, t, idim) -> (b, 1, t, idim).
        x = x.unsqueeze(1)
        # After the convolutions the layout is (b, c, t, f).
        x = self.conv(x)
        b, c, t, f = paddle.shape(x)
        # Move time before channels, then merge channels and features so each
        # frame becomes a single vector of size c * f for the linear layer.
        x = self.out(x.transpose([0, 2, 1, 3]).reshape([b, t, c * f]))
        if x_mask is None:
            return x, None
        # Apply the same kernel-3 / stride-2 length reduction to the mask
        # twice, once per convolution, so it stays aligned with x.
        return x, x_mask[:, :, :-2:2][:, :, :-2:2]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding, i.e. the last layer of ``self.out``.

        Args:
            key (int):
                Must be ``-1``.

        Returns:
            nn.Layer:
                The last layer of ``self.out``.

        Raises:
            NotImplementedError: If ``key`` is anything other than ``-1``.

        """
        if key != -1:
            raise NotImplementedError(
                "Support only `-1` (for `reset_parameters`).")
        return self.out[key]