|
|
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
"""Attention modules for RNN."""
|
|
|
|
import paddle
|
|
|
|
import paddle.nn.functional as F
|
|
|
|
from paddle import nn
|
|
|
|
|
|
|
|
from paddlespeech.t2s.modules.masked_fill import masked_fill
|
|
|
|
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
|
|
|
|
|
|
|
|
|
|
|
|
def _apply_attention_constraint(e,
|
|
|
|
last_attended_idx,
|
|
|
|
backward_window=1,
|
|
|
|
forward_window=3):
|
|
|
|
"""Apply monotonic attention constraint.
|
|
|
|
|
|
|
|
This function apply the monotonic attention constraint
|
|
|
|
introduced in `Deep Voice 3: Scaling
|
|
|
|
Text-to-Speech with Convolutional Sequence Learning`_.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
e(Tensor): Attention energy before applying softmax (1, T).
|
|
|
|
last_attended_idx(int): The index of the inputs of the last attended [0, T].
|
|
|
|
backward_window(int, optional, optional): Backward window size in attention constraint. (Default value = 1)
|
|
|
|
forward_window(int, optional, optional): Forward window size in attetion constraint. (Default value = 3)
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Tensor: Monotonic constrained attention energy (1, T).
|
|
|
|
|
|
|
|
.. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`:
|
|
|
|
https://arxiv.org/abs/1710.07654
|
|
|
|
|
|
|
|
"""
|
|
|
|
if paddle.shape(e)[0] != 1:
|
|
|
|
raise NotImplementedError(
|
|
|
|
"Batch attention constraining is not yet supported.")
|
|
|
|
backward_idx = last_attended_idx - backward_window
|
|
|
|
forward_idx = last_attended_idx + forward_window
|
|
|
|
if backward_idx > 0:
|
|
|
|
e[:, :backward_idx] = -float("inf")
|
|
|
|
if forward_idx < paddle.shape(e)[1]:
|
|
|
|
e[:, forward_idx:] = -float("inf")
|
|
|
|
return e
|
|
|
|
|
|
|
|
|
|
|
|
class AttLoc(nn.Layer):
|
|
|
|
"""location-aware attention module.
|
|
|
|
|
|
|
|
Reference: Attention-Based Models for Speech Recognition
|
|
|
|
(https://arxiv.org/pdf/1506.07503.pdf)
|
|
|
|
|
|
|
|
Args:
|
|
|
|
eprojs (int): projection-units of encoder
|
|
|
|
dunits (int): units of decoder
|
|
|
|
att_dim (int): attention dimension
|
|
|
|
aconv_chans (int): channels of attention convolution
|
|
|
|
aconv_filts (int): filter size of attention convolution
|
|
|
|
han_mode (bool): flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
eprojs,
|
|
|
|
dunits,
|
|
|
|
att_dim,
|
|
|
|
aconv_chans,
|
|
|
|
aconv_filts,
|
|
|
|
han_mode=False):
|
|
|
|
super().__init__()
|
|
|
|
self.mlp_enc = nn.Linear(eprojs, att_dim)
|
|
|
|
self.mlp_dec = nn.Linear(dunits, att_dim, bias_attr=False)
|
|
|
|
self.mlp_att = nn.Linear(aconv_chans, att_dim, bias_attr=False)
|
|
|
|
self.loc_conv = nn.Conv2D(
|
|
|
|
1,
|
|
|
|
aconv_chans,
|
|
|
|
(1, 2 * aconv_filts + 1),
|
|
|
|
padding=(0, aconv_filts),
|
|
|
|
bias_attr=False, )
|
|
|
|
self.gvec = nn.Linear(att_dim, 1)
|
|
|
|
|
|
|
|
self.dunits = dunits
|
|
|
|
self.eprojs = eprojs
|
|
|
|
self.att_dim = att_dim
|
|
|
|
self.h_length = None
|
|
|
|
self.enc_h = None
|
|
|
|
self.pre_compute_enc_h = None
|
|
|
|
self.mask = None
|
|
|
|
self.han_mode = han_mode
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
"""reset states"""
|
|
|
|
self.h_length = None
|
|
|
|
self.enc_h = None
|
|
|
|
self.pre_compute_enc_h = None
|
|
|
|
self.mask = None
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
self,
|
|
|
|
enc_hs_pad,
|
|
|
|
enc_hs_len,
|
|
|
|
dec_z,
|
|
|
|
att_prev,
|
|
|
|
scaling=2.0,
|
|
|
|
last_attended_idx=None,
|
|
|
|
backward_window=1,
|
|
|
|
forward_window=3, ):
|
|
|
|
"""Calculate AttLoc forward propagation.
|
|
|
|
Args:
|
|
|
|
enc_hs_pad(Tensor): padded encoder hidden state (B, T_max, D_enc)
|
|
|
|
enc_hs_len(Tensor): padded encoder hidden state length (B)
|
|
|
|
dec_z(Tensor dec_z): decoder hidden state (B, D_dec)
|
|
|
|
att_prev(Tensor): previous attention weight (B, T_max)
|
|
|
|
scaling(float, optional): scaling parameter before applying softmax (Default value = 2.0)
|
|
|
|
forward_window(Tensor, optional): forward window size when constraining attention (Default value = 3)
|
|
|
|
last_attended_idx(int, optional): index of the inputs of the last attended (Default value = None)
|
|
|
|
backward_window(int, optional): backward window size in attention constraint (Default value = 1)
|
|
|
|
forward_window(int, optional): forward window size in attetion constraint (Default value = 3)
|
|
|
|
Returns:
|
|
|
|
Tensor: attention weighted encoder state (B, D_enc)
|
|
|
|
Tensor: previous attention weights (B, T_max)
|
|
|
|
"""
|
|
|
|
batch = paddle.shape(enc_hs_pad)[0]
|
|
|
|
# pre-compute all h outside the decoder loop
|
|
|
|
if self.pre_compute_enc_h is None or self.han_mode:
|
|
|
|
# (utt, frame, hdim)
|
|
|
|
self.enc_h = enc_hs_pad
|
|
|
|
self.h_length = paddle.shape(self.enc_h)[1]
|
|
|
|
# (utt, frame, att_dim)
|
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
|
|
|
|
|
|
|
|
if dec_z is None:
|
|
|
|
dec_z = paddle.zeros([batch, self.dunits])
|
|
|
|
else:
|
|
|
|
dec_z = dec_z.reshape([batch, self.dunits])
|
|
|
|
|
|
|
|
# initialize attention weight with uniform dist.
|
|
|
|
if paddle.sum(att_prev) == 0:
|
|
|
|
# if no bias, 0 0-pad goes 0
|
|
|
|
att_prev = 1.0 - make_pad_mask(enc_hs_len)
|
|
|
|
att_prev = att_prev / enc_hs_len.unsqueeze(-1)
|
|
|
|
|
|
|
|
# att_prev: (utt, frame) -> (utt, 1, 1, frame)
|
|
|
|
# -> (utt, att_conv_chans, 1, frame)
|
|
|
|
att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
|
|
|
|
# att_conv: (utt, att_conv_chans, 1, frame) -> (utt, frame, att_conv_chans)
|
|
|
|
att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
|
|
|
|
# att_conv: (utt, frame, att_conv_chans) -> (utt, frame, att_dim)
|
|
|
|
att_conv = self.mlp_att(att_conv)
|
|
|
|
# dec_z_tiled: (utt, frame, att_dim)
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).reshape([batch, 1, self.att_dim])
|
|
|
|
|
|
|
|
# dot with gvec
|
|
|
|
# (utt, frame, att_dim) -> (utt, frame)
|
|
|
|
e = paddle.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
|
|
|
|
e = self.gvec(e).squeeze(2)
|
|
|
|
|
|
|
|
# NOTE: consider zero padding when compute w.
|
|
|
|
if self.mask is None:
|
|
|
|
self.mask = make_pad_mask(enc_hs_len)
|
|
|
|
|
|
|
|
e = masked_fill(e, self.mask, -float("inf"))
|
|
|
|
# apply monotonic attention constraint (mainly for TTS)
|
|
|
|
if last_attended_idx is not None:
|
|
|
|
e = _apply_attention_constraint(e, last_attended_idx,
|
|
|
|
backward_window, forward_window)
|
|
|
|
|
|
|
|
w = F.softmax(scaling * e, axis=1)
|
|
|
|
|
|
|
|
# weighted sum over frames
|
|
|
|
# utt x hdim
|
|
|
|
c = paddle.sum(
|
|
|
|
self.enc_h * w.reshape([batch, self.h_length, 1]), axis=1)
|
|
|
|
return c, w
|
|
|
|
|
|
|
|
|
|
|
|
class AttForward(nn.Layer):
|
|
|
|
"""Forward attention module.
|
|
|
|
Reference
|
|
|
|
----------
|
|
|
|
Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
|
|
|
|
(https://arxiv.org/pdf/1807.06736.pdf)
|
|
|
|
|
|
|
|
Args:
|
|
|
|
eprojs (int): projection-units of encoder
|
|
|
|
dunits (int): units of decoder
|
|
|
|
att_dim (int): attention dimension
|
|
|
|
aconv_chans (int): channels of attention convolution
|
|
|
|
aconv_filts (int): filter size of attention convolution
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
|
|
|
|
super().__init__()
|
|
|
|
self.mlp_enc = nn.Linear(eprojs, att_dim)
|
|
|
|
self.mlp_dec = nn.Linear(dunits, att_dim, bias_attr=False)
|
|
|
|
self.mlp_att = nn.Linear(aconv_chans, att_dim, bias_attr=False)
|
|
|
|
self.loc_conv = nn.Conv2D(
|
|
|
|
1,
|
|
|
|
aconv_chans,
|
|
|
|
(1, 2 * aconv_filts + 1),
|
|
|
|
padding=(0, aconv_filts),
|
|
|
|
bias_attr=False, )
|
|
|
|
self.gvec = nn.Linear(att_dim, 1)
|
|
|
|
self.dunits = dunits
|
|
|
|
self.eprojs = eprojs
|
|
|
|
self.att_dim = att_dim
|
|
|
|
self.h_length = None
|
|
|
|
self.enc_h = None
|
|
|
|
self.pre_compute_enc_h = None
|
|
|
|
self.mask = None
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
"""reset states"""
|
|
|
|
self.h_length = None
|
|
|
|
self.enc_h = None
|
|
|
|
self.pre_compute_enc_h = None
|
|
|
|
self.mask = None
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
self,
|
|
|
|
enc_hs_pad,
|
|
|
|
enc_hs_len,
|
|
|
|
dec_z,
|
|
|
|
att_prev,
|
|
|
|
scaling=1.0,
|
|
|
|
last_attended_idx=None,
|
|
|
|
backward_window=1,
|
|
|
|
forward_window=3, ):
|
|
|
|
"""Calculate AttForward forward propagation.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
enc_hs_pad(Tensor): padded encoder hidden state (B, T_max, D_enc)
|
|
|
|
enc_hs_len(list): padded encoder hidden state length (B,)
|
|
|
|
dec_z(Tensor): decoder hidden state (B, D_dec)
|
|
|
|
att_prev(Tensor): attention weights of previous step (B, T_max)
|
|
|
|
scaling(float, optional): scaling parameter before applying softmax (Default value = 1.0)
|
|
|
|
last_attended_idx(int, optional): index of the inputs of the last attended (Default value = None)
|
|
|
|
backward_window(int, optional): backward window size in attention constraint (Default value = 1)
|
|
|
|
forward_window(int, optional): (Default value = 3)
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Tensor: attention weighted encoder state (B, D_enc)
|
|
|
|
Tensor: previous attention weights (B, T_max)
|
|
|
|
"""
|
|
|
|
batch = len(enc_hs_pad)
|
|
|
|
# pre-compute all h outside the decoder loop
|
|
|
|
if self.pre_compute_enc_h is None:
|
|
|
|
self.enc_h = enc_hs_pad # utt x frame x hdim
|
|
|
|
self.h_length = paddle.shape(self.enc_h)[1]
|
|
|
|
# utt x frame x att_dim
|
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
|
|
|
|
|
|
|
|
if dec_z is None:
|
|
|
|
dec_z = paddle.zeros([batch, self.dunits])
|
|
|
|
else:
|
|
|
|
dec_z = dec_z.reshape([batch, self.dunits])
|
|
|
|
|
|
|
|
if att_prev is None:
|
|
|
|
# initial attention will be [1, 0, 0, ...]
|
|
|
|
att_prev = paddle.zeros([*paddle.shape(enc_hs_pad)[:2]])
|
|
|
|
att_prev[:, 0] = 1.0
|
|
|
|
|
|
|
|
# att_prev: utt x frame -> utt x 1 x 1 x frame
|
|
|
|
# -> utt x att_conv_chans x 1 x frame
|
|
|
|
att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
|
|
|
|
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
|
|
|
|
att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
|
|
|
|
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
|
|
|
|
att_conv = self.mlp_att(att_conv)
|
|
|
|
|
|
|
|
# dec_z_tiled: utt x frame x att_dim
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1)
|
|
|
|
|
|
|
|
# dot with gvec
|
|
|
|
# utt x frame x att_dim -> utt x frame
|
|
|
|
e = self.gvec(
|
|
|
|
paddle.tanh(self.pre_compute_enc_h + dec_z_tiled +
|
|
|
|
att_conv)).squeeze(2)
|
|
|
|
|
|
|
|
# NOTE: consider zero padding when compute w.
|
|
|
|
if self.mask is None:
|
|
|
|
self.mask = make_pad_mask(enc_hs_len)
|
|
|
|
e = masked_fill(e, self.mask, -float("inf"))
|
|
|
|
|
|
|
|
# apply monotonic attention constraint (mainly for TTS)
|
|
|
|
if last_attended_idx is not None:
|
|
|
|
e = _apply_attention_constraint(e, last_attended_idx,
|
|
|
|
backward_window, forward_window)
|
|
|
|
|
|
|
|
w = F.softmax(scaling * e, axis=1)
|
|
|
|
|
|
|
|
# forward attention
|
|
|
|
att_prev_shift = F.pad(att_prev, (0, 0, 1, 0))[:, :-1]
|
|
|
|
|
|
|
|
w = (att_prev + att_prev_shift) * w
|
|
|
|
# NOTE: clip is needed to avoid nan gradient
|
|
|
|
w = F.normalize(paddle.clip(w, 1e-6), p=1, axis=1)
|
|
|
|
|
|
|
|
# weighted sum over flames
|
|
|
|
# utt x hdim
|
|
|
|
# NOTE use bmm instead of sum(*)
|
|
|
|
c = paddle.sum(self.enc_h * w.unsqueeze(-1), axis=1)
|
|
|
|
|
|
|
|
return c, w
|
|
|
|
|
|
|
|
|
|
|
|
class AttForwardTA(nn.Layer):
|
|
|
|
"""Forward attention with transition agent module.
|
|
|
|
Reference:
|
|
|
|
Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
|
|
|
|
(https://arxiv.org/pdf/1807.06736.pdf)
|
|
|
|
|
|
|
|
Args:
|
|
|
|
eunits (int): units of encoder
|
|
|
|
dunits (int): units of decoder
|
|
|
|
att_dim (int): attention dimension
|
|
|
|
aconv_chans (int): channels of attention convolution
|
|
|
|
aconv_filts (int): filter size of attention convolution
|
|
|
|
odim (int): output dimension
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim):
|
|
|
|
super().__init__()
|
|
|
|
self.mlp_enc = nn.Linear(eunits, att_dim)
|
|
|
|
self.mlp_dec = nn.Linear(dunits, att_dim, bias_attr=False)
|
|
|
|
self.mlp_ta = nn.Linear(eunits + dunits + odim, 1)
|
|
|
|
self.mlp_att = nn.Linear(aconv_chans, att_dim, bias_attr=False)
|
|
|
|
self.loc_conv = nn.Conv2D(
|
|
|
|
1,
|
|
|
|
aconv_chans,
|
|
|
|
(1, 2 * aconv_filts + 1),
|
|
|
|
padding=(0, aconv_filts),
|
|
|
|
bias_attr=False, )
|
|
|
|
self.gvec = nn.Linear(att_dim, 1)
|
|
|
|
self.dunits = dunits
|
|
|
|
self.eunits = eunits
|
|
|
|
self.att_dim = att_dim
|
|
|
|
self.h_length = None
|
|
|
|
self.enc_h = None
|
|
|
|
self.pre_compute_enc_h = None
|
|
|
|
self.mask = None
|
|
|
|
self.trans_agent_prob = 0.5
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
self.h_length = None
|
|
|
|
self.enc_h = None
|
|
|
|
self.pre_compute_enc_h = None
|
|
|
|
self.mask = None
|
|
|
|
self.trans_agent_prob = 0.5
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
self,
|
|
|
|
enc_hs_pad,
|
|
|
|
enc_hs_len,
|
|
|
|
dec_z,
|
|
|
|
att_prev,
|
|
|
|
out_prev,
|
|
|
|
scaling=1.0,
|
|
|
|
last_attended_idx=None,
|
|
|
|
backward_window=1,
|
|
|
|
forward_window=3, ):
|
|
|
|
"""Calculate AttForwardTA forward propagation.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
enc_hs_pad(Tensor): padded encoder hidden state (B, Tmax, eunits)
|
|
|
|
enc_hs_len(list Tensor): padded encoder hidden state length (B,)
|
|
|
|
dec_z(Tensor): decoder hidden state (B, dunits)
|
|
|
|
att_prev(Tensor): attention weights of previous step (B, T_max)
|
|
|
|
out_prev(Tensor): decoder outputs of previous step (B, odim)
|
|
|
|
scaling(float, optional): scaling parameter before applying softmax (Default value = 1.0)
|
|
|
|
last_attended_idx(int, optional): index of the inputs of the last attended (Default value = None)
|
|
|
|
backward_window(int, optional): backward window size in attention constraint (Default value = 1)
|
|
|
|
forward_window(int, optional): (Default value = 3)
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Tensor: attention weighted encoder state (B, dunits)
|
|
|
|
Tensor: previous attention weights (B, Tmax)
|
|
|
|
"""
|
|
|
|
batch = len(enc_hs_pad)
|
|
|
|
# pre-compute all h outside the decoder loop
|
|
|
|
if self.pre_compute_enc_h is None:
|
|
|
|
self.enc_h = enc_hs_pad # utt x frame x hdim
|
|
|
|
self.h_length = paddle.shape(self.enc_h)[1]
|
|
|
|
# utt x frame x att_dim
|
|
|
|
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
|
|
|
|
|
|
|
|
if dec_z is None:
|
|
|
|
dec_z = paddle.zeros([batch, self.dunits])
|
|
|
|
else:
|
|
|
|
dec_z = dec_z.reshape([batch, self.dunits])
|
|
|
|
|
|
|
|
if att_prev is None:
|
|
|
|
# initial attention will be [1, 0, 0, ...]
|
|
|
|
att_prev = paddle.zeros([*paddle.shape(enc_hs_pad)[:2]])
|
|
|
|
att_prev[:, 0] = 1.0
|
|
|
|
|
|
|
|
# att_prev: utt x frame -> utt x 1 x 1 x frame
|
|
|
|
# -> utt x att_conv_chans x 1 x frame
|
|
|
|
att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
|
|
|
|
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
|
|
|
|
att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
|
|
|
|
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
|
|
|
|
att_conv = self.mlp_att(att_conv)
|
|
|
|
|
|
|
|
# dec_z_tiled: utt x frame x att_dim
|
|
|
|
dec_z_tiled = self.mlp_dec(dec_z).reshape([batch, 1, self.att_dim])
|
|
|
|
|
|
|
|
# dot with gvec
|
|
|
|
# utt x frame x att_dim -> utt x frame
|
|
|
|
e = self.gvec(
|
|
|
|
paddle.tanh(att_conv + self.pre_compute_enc_h +
|
|
|
|
dec_z_tiled)).squeeze(2)
|
|
|
|
|
|
|
|
# NOTE consider zero padding when compute w.
|
|
|
|
if self.mask is None:
|
|
|
|
self.mask = make_pad_mask(enc_hs_len)
|
|
|
|
e = masked_fill(e, self.mask, -float("inf"))
|
|
|
|
|
|
|
|
# apply monotonic attention constraint (mainly for TTS)
|
|
|
|
if last_attended_idx is not None:
|
|
|
|
e = _apply_attention_constraint(e, last_attended_idx,
|
|
|
|
backward_window, forward_window)
|
|
|
|
|
|
|
|
w = F.softmax(scaling * e, axis=1)
|
|
|
|
|
|
|
|
# forward attention
|
|
|
|
# att_prev_shift = F.pad(att_prev.unsqueeze(0), (1, 0), data_format='NCL').squeeze(0)[:, :-1]
|
|
|
|
att_prev_shift = F.pad(att_prev, (0, 0, 1, 0))[:, :-1]
|
|
|
|
w = (self.trans_agent_prob * att_prev +
|
|
|
|
(1 - self.trans_agent_prob) * att_prev_shift) * w
|
|
|
|
# NOTE: clip is needed to avoid nan gradient
|
|
|
|
w = F.normalize(paddle.clip(w, 1e-6), p=1, axis=1)
|
|
|
|
|
|
|
|
# weighted sum over flames
|
|
|
|
# utt x hdim
|
|
|
|
# NOTE use bmm instead of sum(*)
|
|
|
|
c = paddle.sum(
|
|
|
|
self.enc_h * w.reshape([batch, self.h_length, 1]), axis=1)
|
|
|
|
|
|
|
|
# update transition agent prob
|
|
|
|
self.trans_agent_prob = F.sigmoid(
|
|
|
|
self.mlp_ta(paddle.concat([c, out_prev, dec_z], axis=1)))
|
|
|
|
|
|
|
|
return c, w
|