add attention, common utils, hack paddle

pull/556/head
Hui Zhang 5 years ago
parent 16fa4245ec
commit 7635f98bce

@ -11,3 +11,307 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Union
from typing import Any
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
########### hack logging #############
logger.warn = logging.warning
########### hack paddle #############
paddle.bool = 'bool'
paddle.float16 = 'float16'
paddle.half = 'float16'
paddle.float32 = 'float32'
paddle.float = 'float32'
paddle.float64 = 'float64'
paddle.double = 'float64'
paddle.int8 = 'int8'
paddle.int16 = 'int16'
paddle.short = 'int16'
paddle.int32 = 'int32'
paddle.int = 'int32'
paddle.int64 = 'int64'
paddle.long = 'int64'
paddle.uint8 = 'uint8'
paddle.complex64 = 'complex64'
paddle.complex128 = 'complex128'
paddle.cdouble = 'complex128'
if not hasattr(paddle, 'softmax'):
logger.warn("register user softmax to paddle, remove this when fixed!")
setattr(paddle, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle, 'sigmoid'):
logger.warn("register user sigmoid to paddle, remove this when fixed!")
setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle, 'relu'):
logger.warn("register user relu to paddle, remove this when fixed!")
setattr(paddle, 'relu', paddle.nn.functional.relu)
def cat(xs, dim=0):
return paddle.concat(xs, axis=dim)
if not hasattr(paddle, 'cat'):
logger.warn(
"override cat of paddle if exists or register, remove this when fixed!")
paddle.cat = cat
########### hack paddle.Tensor #############
if not hasattr(paddle.Tensor, 'numel'):
logger.warn(
"override numel of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.numel = paddle.numel
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place))
if not hasattr(paddle.Tensor, 'eq'):
logger.warn(
"override eq of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.eq = eq
def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
return xs
if not hasattr(paddle.Tensor, 'contiguous'):
logger.warn(
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.contiguous = contiguous
def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
nargs = len(args)
assert (nargs <= 1)
s = paddle.shape(xs)
if nargs == 1:
return s[args[0]]
else:
return s
# `to_static` does not process the `size` property; some `paddle` APIs may depend on it.
logger.warn(
"override size of paddle.Tensor "
"(`to_static` does not process the `size` property, and some `paddle` APIs may depend on it), remove this when fixed!"
)
paddle.Tensor.size = size
def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
return xs.reshape(args)
if not hasattr(paddle.Tensor, 'view'):
logger.warn("register user view to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view = view
def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
return xs.reshape(ys.size())
if not hasattr(paddle.Tensor, 'view_as'):
logger.warn(
"register user view_as to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view_as = view_as
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert xs.shape == mask.shape
trues = paddle.ones_like(xs) * value
xs = paddle.where(mask, trues, xs)
return xs
if not hasattr(paddle.Tensor, 'masked_fill'):
logger.warn(
"register user masked_fill to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill = masked_fill
def masked_fill_(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert xs.shape == mask.shape
trues = paddle.ones_like(xs) * value
ret = paddle.where(mask, trues, xs)
paddle.assign(ret, output=xs)
if not hasattr(paddle.Tensor, 'masked_fill_'):
logger.warn(
"register user masked_fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill_ = masked_fill_
def fill_(xs: paddle.Tensor, value: Union[float, int]):
val = paddle.full_like(xs, value)
paddle.assign(val, output=xs)
if not hasattr(paddle.Tensor, 'fill_'):
logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.fill_ = fill_
def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
return paddle.tile(xs, size)
if not hasattr(paddle.Tensor, 'repeat'):
logger.warn(
"register user repeat to paddle.Tensor, remove this when fixed!")
paddle.Tensor.repeat = repeat
if not hasattr(paddle.Tensor, 'softmax'):
logger.warn(
"register user softmax to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle.Tensor, 'sigmoid'):
logger.warn(
"register user sigmoid to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)
if not hasattr(paddle.Tensor, 'relu'):
logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)
########### hack paddle.nn.functional #############
def glu(x: paddle.Tensor, dim=-1) -> paddle.Tensor:
"""The gated linear unit (GLU) activation."""
a, b = x.split(2, axis=dim)
act_b = F.sigmoid(b)
return a * act_b
if not hasattr(paddle.nn.functional, 'glu'):
logger.warn(
"register user glu to paddle.nn.functional, remove this when fixed!")
setattr(paddle.nn.functional, 'glu', glu)
# def softplus(x):
# """Softplus function."""
# if hasattr(paddle.nn.functional, 'softplus'):
# #return paddle.nn.functional.softplus(x.float()).type_as(x)
# return paddle.nn.functional.softplus(x)
# else:
# raise NotImplementedError
# def gelu_accurate(x):
# """Gaussian Error Linear Units (GELU) activation."""
# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
# if not hasattr(gelu_accurate, "_a"):
# gelu_accurate._a = math.sqrt(2 / math.pi)
# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
# (x + 0.044715 * paddle.pow(x, 3))))
# def gelu(x):
# """Gaussian Error Linear Units (GELU) activation."""
# if hasattr(nn.functional, 'gelu'):
# #return nn.functional.gelu(x.float()).type_as(x)
# return nn.functional.gelu(x)
# else:
# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
# hack loss
def ctc_loss(logits,
labels,
input_lengths,
label_lengths,
blank=0,
reduction='mean',
norm_by_times=True):
#logger.info("my ctc loss with norm by times")
## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
input_lengths, label_lengths)
loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
logger.info(f"warpctc loss: {loss_out}/{loss_out.shape} ")
assert reduction in ['mean', 'sum', 'none']
if reduction == 'mean':
loss_out = paddle.mean(loss_out / label_lengths)
elif reduction == 'sum':
loss_out = paddle.sum(loss_out)
logger.info(f"ctc loss: {loss_out}")
return loss_out
logger.warn(
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
)
F.ctc_loss = ctc_loss
########### hack paddle.nn #############
class GLU(nn.Layer):
"""Gated Linear Units (GLU) Layer"""
def __init__(self, dim: int=-1):
super().__init__()
self.dim = dim
def forward(self, xs):
return glu(xs, dim=self.dim)
if not hasattr(paddle.nn, 'GLU'):
logger.warn("register user GLU to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'GLU', GLU)
# TODO(Hui Zhang): remove this Layer
class ConstantPad2d(nn.Layer):
"""Pads the input tensor boundaries with a constant value.
For N-dimensional padding, use paddle.nn.functional.pad().
"""
def __init__(self, padding: Union[tuple, list, int], value: float):
"""
Args:
padding (Union[tuple, list, int]): the size of the padding.
If an int, uses the same padding on all boundaries.
If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom).
value (float): pad value
"""
super().__init__()
self.padding = padding if isinstance(padding,
(tuple, list)) else [padding] * 4
self.value = value
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
return nn.functional.pad(
xs,
self.padding,
mode='constant',
value=self.value,
data_format='NCHW')
if not hasattr(paddle.nn, 'ConstantPad2d'):
logger.warn(
"register user ConstantPad2d to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d)
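
A rough usage sketch of the Tensor patches above (not part of the commit; assumes a paddle 2.x environment in which this module has already been imported so the monkey patches are applied):

import paddle

x = paddle.randn([2, 3])
mask = x > 0.0                    # bool mask, same shape as x
print(x.size())                   # full shape, torch-style size()
print(x.size(1))                  # a single dimension
y = x.view(3, 2)                  # registered `view` -> reshape
z = x.masked_fill(mask, 0.0)      # out-of-place masked fill
r = x.repeat(2, 1)                # registered `repeat` -> paddle.tile
print(y.shape, z.shape, r.shape)  # [3, 2] [2, 3] [4, 3]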

@ -10,248 +10,4 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Union
from typing import Any
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
logger.warn = logging.warning
# TODO(Hui Zhang): remove this hack
paddle.bool = 'bool'
paddle.float16 = 'float16'
paddle.float32 = 'float32'
paddle.float64 = 'float64'
paddle.int8 = 'int8'
paddle.int16 = 'int16'
paddle.int32 = 'int32'
paddle.int64 = 'int64'
paddle.uint8 = 'uint8'
paddle.complex64 = 'complex64'
paddle.complex128 = 'complex128'
if not hasattr(paddle.Tensor, 'cat'):
logger.warn(
"override cat of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.cat = paddle.Tensor.concat
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
return xs.equal(paddle.to_tensor(ys, dtype=xs.dtype, place=xs.place))
if not hasattr(paddle.Tensor, 'eq'):
logger.warn(
"override eq of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.eq = eq
def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
return xs
if not hasattr(paddle.Tensor, 'contiguous'):
logger.warn(
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.contiguous = contiguous
def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
nargs = len(args)
assert (nargs <= 1)
s = paddle.shape(xs)
if nargs == 1:
return s[args[0]]
else:
return s
#`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
logger.warn(
"override size of paddle.Tensor "
"(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
)
paddle.Tensor.size = size
def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
return xs.reshape(args)
if not hasattr(paddle.Tensor, 'view'):
logger.warn("register user view to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view = view
def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert xs.shape == mask.shape
trues = paddle.ones_like(xs) * value
xs = paddle.where(mask, trues, xs)
return xs
if not hasattr(paddle.Tensor, 'masked_fill'):
logger.warn(
"register user masked_fill to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill = masked_fill
def masked_fill_(xs: paddle.Tensor,
mask: paddle.Tensor,
value: Union[float, int]):
assert xs.shape == mask.shape
trues = paddle.ones_like(xs) * value
ret = paddle.where(mask, trues, xs)
paddle.assign(ret, output=xs)
if not hasattr(paddle.Tensor, 'masked_fill_'):
logger.warn(
"register user masked_fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill_ = masked_fill_
def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
return paddle.tile(xs, size)
if not hasattr(paddle.Tensor, 'repeat'):
logger.warn(
"register user repeat to paddle.Tensor, remove this when fixed!")
paddle.Tensor.repeat = repeat
# def softplus(x):
# """Softplus function."""
# if hasattr(paddle.nn.functional, 'softplus'):
# #return paddle.nn.functional.softplus(x.float()).type_as(x)
# return paddle.nn.functional.softplus(x)
# else:
# raise NotImplementedError
# def gelu_accurate(x):
# """Gaussian Error Linear Units (GELU) activation."""
# # [reference] https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/modules/gelu.py
# if not hasattr(gelu_accurate, "_a"):
# gelu_accurate._a = math.sqrt(2 / math.pi)
# return 0.5 * x * (1 + paddle.tanh(gelu_accurate._a *
# (x + 0.044715 * paddle.pow(x, 3))))
# def gelu(x):
# """Gaussian Error Linear Units (GELU) activation."""
# if hasattr(nn.functional, 'gelu'):
# #return nn.functional.gelu(x.float()).type_as(x)
# return nn.functional.gelu(x)
# else:
# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
def glu(x: paddle.Tensor, dim=-1) -> paddle.Tensor:
"""The gated linear unit (GLU) activation."""
a, b = x.split(2, axis=dim)
act_b = F.sigmoid(b)
return a * act_b
if not hasattr(paddle.nn.functional, 'glu'):
logger.warn(
"register user glu to paddle.nn.functional, remove this when fixed!")
setattr(paddle.nn.functional, 'glu', glu)
# TODO(Hui Zhang): remove this activation
class GLU(nn.Layer):
"""Gated Linear Units (GLU) Layer"""
def __init__(self, dim: int=-1):
super().__init__()
self.dim = dim
def forward(self, xs):
return glu(xs, dim=self.dim)
if not hasattr(paddle.nn, 'GLU'):
logger.warn("register user GLU to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'GLU', GLU)
# TODO(Hui Zhang): remove this Layer
class ConstantPad2d(nn.Layer):
"""Pads the input tensor boundaries with a constant value.
For N-dimensional padding, use paddle.nn.functional.pad().
"""
def __init__(self, padding: Union[tuple, list, int], value: float):
"""
Args:
paddle ([tuple]): the size of the padding.
If is int, uses the same padding in all boundaries.
If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom)
value ([flaot]): pad value
"""
self.padding = padding if isinstance(padding,
[tuple, list]) else [padding] * 4
self.value = value
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
return nn.functional.pad(
xs,
self.padding,
mode='constant',
value=self.value,
data_format='NCHW')
if not hasattr(paddle.nn, 'ConstantPad2d'):
logger.warn(
"register user ConstantPad2d to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d)
if not hasattr(paddle, 'softmax'):
logger.warn("register user softmax to paddle, remove this when fixed!")
setattr(paddle, 'softmax', paddle.nn.functional.softmax)
if not hasattr(paddle, 'sigmoid'):
logger.warn("register user softmax to paddle, remove this when fixed!")
setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)
# hack loss
def ctc_loss(logits,
labels,
input_lengths,
label_lengths,
blank=0,
reduction='mean',
norm_by_times=True):
#logger.info("my ctc loss with norm by times")
## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
input_lengths, label_lengths)
loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
logger.info(f"warpctc loss: {loss_out}/{loss_out.shape} ")
assert reduction in ['mean', 'sum', 'none']
if reduction == 'mean':
loss_out = paddle.mean(loss_out / label_lengths)
elif reduction == 'sum':
loss_out = paddle.sum(loss_out)
logger.info(f"ctc loss: {loss_out}")
return loss_out
logger.warn(
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
)
F.ctc_loss = ctc_loss

@ -25,7 +25,7 @@ from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
-__all__ = ["brelu", "LinearGLUBlock", "ConstantPad2d", "ConvGLUBlock"]
+__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"]
def brelu(x, t_min=0.0, t_max=24.0, name=None):
@ -50,33 +50,6 @@ class LinearGLUBlock(nn.Layer):
return glu(self.fc(xs), dim=-1)
# TODO(Hui Zhang): remove this Layer
class ConstantPad2d(nn.Layer):
"""Pads the input tensor boundaries with a constant value.
For N-dimensional padding, use paddle.nn.functional.pad().
"""
def __init__(self, padding: Union[tuple, list, int], value: float):
"""
Args:
paddle ([tuple]): the size of the padding.
If is int, uses the same padding in all boundaries.
If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom)
value ([flaot]): pad value
"""
self.padding = padding if isinstance(padding,
[tuple, list]) else [padding] * 4
self.value = value
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
return nn.functional.pad(
xs,
self.padding,
mode='constant',
value=self.value,
data_format='NCHW')
class ConvGLUBlock(nn.Layer):
def __init__(self, kernel_size, in_ch, out_ch, bottlececk_dim=0,
dropout=0.):
@ -159,3 +132,18 @@ class ConvGLUBlock(nn.Layer):
xs = self.layers(xs) # `[B, out_ch * 2, T ,1]`
xs = xs + residual
return xs
def get_activation(act):
"""Return activation function."""
# Lazy load to avoid unused import
activation_funcs = {
"hardtanh": paddle.nn.Hardtanh,
"tanh": paddle.nn.Tanh,
"relu": paddle.nn.ReLU,
"selu": paddle.nn.SELU,
"swish": paddle.nn.Swish,
"gelu": paddle.nn.GELU
}
return activation_funcs[act]()
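
For illustration only (not in the diff), get_activation would typically be used like this when wiring up a model:

act = get_activation("relu")    # returns an instantiated paddle.nn.ReLU()
block = paddle.nn.Sequential(
    paddle.nn.Linear(256, 256),
    act, )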

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""
import math
import logging
@ -26,6 +25,10 @@ logger = logging.getLogger(__name__)
__all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"] __all__ = ["MultiHeadedAttention", "RelPositionMultiHeadedAttention"]
# Relative Positional Encodings
# https://www.jianshu.com/p/c0608efcc26f
# https://zhuanlan.zhihu.com/p/344604604
class MultiHeadedAttention(nn.Layer):
"""Multi-Head Attention layer."""
@ -89,8 +92,8 @@ class MultiHeadedAttention(nn.Layer):
mask (paddle.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
-paddle.Tensor: Transformed value (#batch, time1, d_model)
-weighted by the attention score (#batch, time1, time2).
+paddle.Tensor: Transformed value weighted
+by the attention score, (#batch, time1, d_model).
"""
n_batch = value.size(0)
if mask is not None:
@ -126,8 +129,8 @@ class MultiHeadedAttention(nn.Layer):
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
-scores = paddle.matmul(q, k.transpose(
-[0, 1, 3, 2])) / math.sqrt(self.d_k)
+scores = paddle.matmul(q,
+k.transpose([0, 1, 3, 2])) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
@ -147,76 +150,78 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
self.linear_pos = nn.Linear(n_feat, n_feat, bias_attr=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
-self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
-self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
-torch.nn.init.xavier_uniform_(self.pos_bias_u)
-torch.nn.init.xavier_uniform_(self.pos_bias_v)
+#self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+#self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+#torch.nn.init.xavier_uniform_(self.pos_bias_u)
+#torch.nn.init.xavier_uniform_(self.pos_bias_v)
pos_bias_u = self.create_parameter(
[self.h, self.d_k], default_initializer=I.XavierUniform())
self.add_parameter('pos_bias_u', pos_bias_u)
pos_bias_v = self.create_parameter(
(self.h, self.d_k), default_initializer=I.XavierUniform())
self.add_parameter('pos_bias_v', pos_bias_v)
def rel_shift(self, x, zero_triu: bool=False):
"""Compute relative positional encoding.
Args:
-x (torch.Tensor): Input tensor (batch, time, size).
+x (paddle.Tensor): Input tensor (batch, head, time1, time1).
zero_triu (bool): If true, return the lower triangular part of
the matrix.
Returns:
-torch.Tensor: Output tensor.
+paddle.Tensor: Output tensor. (batch, head, time1, time1)
"""
-zero_pad = torch.zeros(
-(x.size()[0], x.size()[1], x.size()[2], 1),
-device=x.device,
-dtype=x.dtype)
-x_padded = torch.cat([zero_pad, x], dim=-1)
-x_padded = x_padded.view(x.size()[0],
-x.size()[1], x.size(3) + 1, x.size(2))
-x = x_padded[:, :, 1:].view_as(x)
+zero_pad = paddle.zeros(
+(x.size(0), x.size(1), x.size(2), 1), dtype=x.dtype)
+x_padded = paddle.cat([zero_pad, x], dim=-1)
+x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2))
+x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]
if zero_triu:
-ones = torch.ones((x.size(2), x.size(3)))
-x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+ones = paddle.ones((x.size(2), x.size(3)))
+x = x * paddle.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self,
-query: torch.Tensor,
-key: torch.Tensor,
-value: torch.Tensor,
-pos_emb: torch.Tensor,
-mask: Optional[torch.Tensor]):
+query: paddle.Tensor,
+key: paddle.Tensor,
+value: paddle.Tensor,
+pos_emb: paddle.Tensor,
+mask: Optional[paddle.Tensor]):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
-query (torch.Tensor): Query tensor (#batch, time1, size).
-key (torch.Tensor): Key tensor (#batch, time2, size).
-value (torch.Tensor): Value tensor (#batch, time2, size).
-pos_emb (torch.Tensor): Positional embedding tensor
-(#batch, time2, size).
-mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+query (paddle.Tensor): Query tensor (#batch, time1, size).
+key (paddle.Tensor): Key tensor (#batch, time2, size).
+value (paddle.Tensor): Value tensor (#batch, time2, size).
+pos_emb (paddle.Tensor): Positional embedding tensor
+(#batch, time1, size).
+mask (paddle.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
-torch.Tensor: Output tensor (#batch, time1, d_model).
+paddle.Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
-q = q.transpose(1, 2) # (batch, time1, head, d_k)
+q = q.transpose([0, 2, 1, 3]) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
-p = p.transpose(1, 2) # (batch, head, time1, d_k)
+p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
-q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+q_with_bias_u = (q + self.pos_bias_u).transpose([0, 2, 1, 3])
# (batch, head, time1, d_k)
-q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+q_with_bias_v = (q + self.pos_bias_v).transpose([0, 2, 1, 3])
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
-matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+matrix_ac = torch.matmul(q_with_bias_u, k.transpose([0, 1, 3, 2]))
# compute matrix b and matrix d
# (batch, head, time1, time2)
-matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+matrix_bd = torch.matmul(q_with_bias_v, p.transpose([0, 1, 3, 2]))
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)
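
The repeated transpose changes in this hunk reflect an API difference rather than a logic change: torch's Tensor.transpose(dim0, dim1) swaps two axes, while paddle's transpose expects a full permutation. A small illustrative check (not part of the commit):

import paddle

q = paddle.randn([2, 5, 4, 8])    # (batch, time1, head, d_k)
q_t = q.transpose([0, 2, 1, 3])   # paddle: pass the full permutation
print(q_t.shape)                  # [2, 4, 5, 8], same as torch's q.transpose(1, 2)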

@ -27,9 +27,6 @@ logger = logging.getLogger(__name__)
__all__ = ["PositionalEncoding", "RelPositionalEncoding"] __all__ = ["PositionalEncoding", "RelPositionalEncoding"]
# TODO(Hui Zhang): remove this hack
paddle.float32 = 'float32'
class PositionalEncoding(nn.Layer):
def __init__(self,
@ -122,11 +119,11 @@ class RelPositionalEncoding(PositionalEncoding):
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
-T = paddle.shape()[1]
-assert offset + T < self.max_len
-#assert offset + x.size(1) < self.max_len
+#T = paddle.shape()[1]
+#assert offset + T < self.max_len
+assert offset + x.size(1) < self.max_len
#self.pe = self.pe.to(x.device)
x = x * self.xscale
-#pos_emb = self.pe[:, offset:offset + x.size(1)]
-pos_emb = self.pe[:, offset:offset + T]
+pos_emb = self.pe[:, offset:offset + x.size(1)]
+#pos_emb = self.pe[:, offset:offset + T]
return self.dropout(x), self.dropout(pos_emb)
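
The new assert/slice pair just takes the rows of the precomputed positional table that match the current input length; schematically (illustrative shapes only, not from the diff):

import paddle

max_len, d_model, offset = 10, 4, 2
pe = paddle.randn([1, max_len, d_model])     # stands in for self.pe
x = paddle.randn([3, 5, d_model])            # (batch, T, d_model), T = 5
assert offset + x.shape[1] < max_len
pos_emb = pe[:, offset:offset + x.shape[1]]  # (1, T, d_model), broadcast over batch
print(pos_emb.shape)                         # [1, 5, 4]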

@ -0,0 +1,113 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unility functions for Transformer."""
import math
import logging
from typing import Tuple, List
import paddle
logger = logging.getLogger(__name__)
__all__ = ["pad_list", "add_sos_eos", "remove_duplicates_and_blank", "log_add"]
IGNORE_ID = -1
def pad_list(xs: List[paddle.Tensor], pad_value: int):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [paddle.ones(4), paddle.ones(2), paddle.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
n_batch = len(xs)
max_len = max([x.size(0) for x in xs])
pad = paddle.zeros([n_batch, max_len], dtype=xs[0].dtype)
pad = pad.fill_(pad_value)
for i in range(n_batch):
pad[i, :xs[i].size(0)] = xs[i]
return pad
def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Add <sos> and <eos> labels.
Args:
ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of <sos>
eos (int): index of <eos>
ignore_id (int): index of padding
Returns:
ys_in (paddle.Tensor) : (B, Lmax + 1)
ys_out (paddle.Tensor) : (B, Lmax + 1)
Examples:
>>> sos_id = 10
>>> eos_id = 11
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=paddle.int32)
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
>>> ys_in
tensor([[10, 1, 2, 3, 4, 5],
[10, 4, 5, 6, 11, 11],
[10, 7, 8, 9, 11, 11]])
>>> ys_out
tensor([[ 1, 2, 3, 4, 5, 11],
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
_sos = paddle.to_tensor(
[sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
_eos = paddle.to_tensor(
[eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp
def log_add(args: List[int]) -> float:
"""
Stable log add
"""
if all(a == -float('inf') for a in args):
return -float('inf')
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp
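
The two helpers without doctest examples can be sanity-checked like this (illustrative values only):

import math

# collapse repeated tokens and drop CTC blanks (id 0)
assert remove_duplicates_and_blank([0, 1, 1, 0, 2, 2, 2, 3, 0]) == [1, 2, 3]

# log_add is a numerically stable log(sum(exp(...)))
assert abs(log_add([-1000.0, -1000.0]) - (-1000.0 + math.log(2))) < 1e-6
assert log_add([-float('inf'), -float('inf')]) == -float('inf')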

@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
from paddle import nn
__all__ = [
@ -36,7 +37,7 @@ def gradient_norm(layer: nn.Layer):
grad_norm_dict = {}
for name, param in layer.state_dict().items():
if param.trainable:
-grad = param.gradient()
+grad = param.gradient() # return numpy.ndarray
grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
return grad_norm_dict

@ -0,0 +1,43 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import logging
from typing import Tuple, List
import paddle
logger = logging.getLogger(__name__)
__all__ = ["th_accuracy"]
def th_accuracy(pad_outputs: paddle.Tensor,
pad_targets: paddle.Tensor,
ignore_label: int) -> float:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(
pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = paddle.sum(
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = paddle.sum(mask)
return float(numerator) / float(denominator)
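
A minimal sketch of how th_accuracy is typically called (shapes and label ids are made up; assumes the Tensor patches from this commit so `.view`/`.size` are available):

import paddle

B, Lmax, D = 2, 3, 5
logits = paddle.randn([B * Lmax, D])                   # flattened decoder outputs
targets = paddle.to_tensor([[1, 2, -1], [3, -1, -1]])  # -1 marks padding
acc = th_accuracy(logits, targets, ignore_label=-1)
print(acc)  # fraction of non-padded positions predicted correctly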

@ -15,10 +15,9 @@
import paddle
import numpy as np
-from deepspeech.models.network import DeepSpeech2
+from deepspeech.models.deepspeech2 import DeepSpeech2Model
if __name__ == '__main__':
batch_size = 2
feat_dim = 161
max_len = 100
@ -28,15 +27,10 @@ if __name__ == '__main__':
text = np.array([[1, 2], [1, 2]], dtype='int32')
text_len = np.array([2] * batch_size, dtype='int32')
-place = paddle.CUDAPlace(0)
-audio = paddle.to_tensor(
-audio, dtype='float32', place=place, stop_gradient=True)
-audio_len = paddle.to_tensor(
-audio_len, dtype='int64', place=place, stop_gradient=True)
-text = paddle.to_tensor(
-text, dtype='int32', place=place, stop_gradient=True)
-text_len = paddle.to_tensor(
-text_len, dtype='int64', place=place, stop_gradient=True)
+audio = paddle.to_tensor(audio, dtype='float32')
+audio_len = paddle.to_tensor(audio_len, dtype='int64')
+text = paddle.to_tensor(text, dtype='int32')
+text_len = paddle.to_tensor(text_len, dtype='int64')
print(audio.shape)
print(audio_len.shape)
@ -44,7 +38,7 @@ if __name__ == '__main__':
print(text_len.shape)
print("-----------------")
-model = DeepSpeech2(
+model = DeepSpeech2Model(
feat_size=feat_dim,
dict_size=10,
num_conv_layers=2,
@ -56,7 +50,7 @@ if __name__ == '__main__':
print('probs.shape', probs.shape)
print("-----------------")
-model2 = DeepSpeech2(
+model2 = DeepSpeech2Model(
feat_size=feat_dim,
dict_size=10,
num_conv_layers=2,
@ -68,7 +62,7 @@ if __name__ == '__main__':
print('probs.shape', probs.shape)
print("-----------------")
-model3 = DeepSpeech2(
+model3 = DeepSpeech2Model(
feat_size=feat_dim,
dict_size=10,
num_conv_layers=2,
@ -80,7 +74,7 @@ if __name__ == '__main__':
print('probs.shape', probs.shape)
print("-----------------")
-model4 = DeepSpeech2(
+model4 = DeepSpeech2Model(
feat_size=feat_dim,
dict_size=10,
num_conv_layers=2,
@ -92,7 +86,7 @@ if __name__ == '__main__':
print('probs.shape', probs.shape)
print("-----------------")
-model5 = DeepSpeech2(
+model5 = DeepSpeech2Model(
feat_size=feat_dim,
dict_size=10,
num_conv_layers=2,
