@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 from typing import Any
 from typing import List
 from typing import Tuple
@@ -20,12 +21,12 @@ import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-from deepspeech.modules.mask import subsequent_mask
-from deepspeech.modules.encoder import TransformerEncoder
 from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
 from deepspeech.models.lm_interface import LMInterface
+from deepspeech.modules.encoder import TransformerEncoder
+from deepspeech.modules.mask import subsequent_mask
 
-import logging
+
 class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
     def __init__(
         self,
@@ -84,15 +85,12 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
             ), "Tie Weights: True need embedding and final dimensions to match"
             self.decoder.weight = self.embed.weight
 
-
     def _target_mask(self, ys_in_pad):
         ys_mask = ys_in_pad != 0
         m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
         return ys_mask.unsqueeze(-2) & m
 
-
-    def forward(
-        self, x: paddle.Tensor, t: paddle.Tensor
-    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+    def forward(self, x: paddle.Tensor, t: paddle.Tensor
+                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Compute LM loss value from buffer sequences.
 
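A note on `_target_mask` in the hunk above: it combines a padding mask (token id 0 is treated as padding) with a causal mask, so each position attends only to earlier, non-padded positions. A minimal standalone sketch of that logic, assuming `subsequent_mask(n)` behaves like a lower-triangular boolean matrix (emulated here with `paddle.tril`; the function name `make_target_mask` is invented for illustration):

```python
import paddle

def make_target_mask(ys_in_pad: paddle.Tensor) -> paddle.Tensor:
    """Illustrative stand-in for TransformerLM._target_mask."""
    n = ys_in_pad.shape[-1]
    # Padding mask: True where the token is real (pad id assumed to be 0).
    ys_mask = (ys_in_pad != 0).unsqueeze(-2)  # (batch, 1, n)
    # Causal mask: position i may only attend to positions j <= i.
    causal = paddle.tril(paddle.ones([n, n], dtype="int32")).astype("bool")
    # Broadcast to (batch, n, n): True iff attendable AND not padding.
    return paddle.logical_and(ys_mask, causal.unsqueeze(0))

ys = paddle.to_tensor([[5, 7, 9, 0, 0]])  # one sequence with two pad tokens
print(make_target_mask(ys).astype("int32"))
```

Rows for the two padded positions still exist in the mask, but no real position can attend to them, which is what keeps padding out of the attention scores.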
@@ -119,7 +117,8 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
         emb = self.embed(x)
         h, _ = self.encoder(emb, xlen)
         y = self.decoder(h)
-        loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
+        loss = F.cross_entropy(
+            y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
         mask = xm.to(dtype=loss.dtype)
         logp = loss * mask.view(-1)
         logp = logp.sum()
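The hunk above only re-wraps the loss call, but the surrounding pattern is worth spelling out: the per-token cross entropy is kept unreduced, multiplied by a 0/1 padding mask, and summed, so padded positions contribute nothing. A hedged sketch of that pattern with made-up names (`logits`, `targets`, `token_mask` are illustrative, not from the diff):

```python
import paddle
import paddle.nn.functional as F

def masked_lm_loss(logits: paddle.Tensor,     # (batch, time, vocab)
                   targets: paddle.Tensor,    # (batch, time), int64
                   token_mask: paddle.Tensor  # (batch, time), bool
                   ) -> paddle.Tensor:
    # Unreduced per-token negative log-likelihood, one value per position.
    loss = F.cross_entropy(
        logits.reshape([-1, logits.shape[-1]]),
        targets.reshape([-1]),
        reduction="none")
    loss = loss.reshape([-1])  # guard against (N, 1)-shaped outputs
    mask = token_mask.astype(loss.dtype).reshape([-1])
    total_nll = (loss * mask).sum()  # padded positions are zeroed out
    count = mask.sum()               # number of real tokens
    return total_nll / count         # mean NLL; paddle.exp() of this is ppl
```

Dividing the summed loss by the token count rather than the batch size is what makes the resulting perplexity comparable across batches with different amounts of padding.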
@@ -150,16 +149,16 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
         emb = self.embed(y)
 
         h, _, cache = self.encoder.forward_one_step(
-            emb, self._target_mask(y), cache=state
-        )
+            emb, self._target_mask(y), cache=state)
         h = self.decoder(h[:, -1])
         logp = F.log_softmax(h).squeeze(0)
         return logp, cache
 
     # batch beam search API (see BatchScorerInterface)
-    def batch_score(
-        self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
-    ) -> Tuple[paddle.Tensor, List[Any]]:
+    def batch_score(self,
+                    ys: paddle.Tensor,
+                    states: List[Any],
+                    xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
         """Score new token batch (required).
 
         Args:
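`batch_score` is the entry point beam search uses: it scores every live hypothesis in one encoder pass and hands each hypothesis back its own cache. A hedged usage sketch, assuming `lm` is an already-constructed `TransformerLM`; the hypothesis contents are made up, and the `xs` placeholder reflects the assumption (as in the interface this code follows) that a pure LM scorer ignores the encoder features:

```python
import paddle

hyps = [[1, 42, 7], [1, 42, 9]]              # two live prefixes (token ids)
ys = paddle.to_tensor(hyps, dtype="int64")   # (n_beam, ylen)
states = [None] * len(hyps)                  # no per-hypothesis cache yet
xs = paddle.zeros([len(hyps), 1])            # placeholder encoder features

logp, states = lm.batch_score(ys, states, xs)
# logp: (n_beam, vocab) next-token log-probabilities, one row per prefix;
# states comes back as one cache list per hypothesis for the next step.
best = paddle.argmax(logp, axis=-1)
```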
@@ -193,13 +192,13 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
 
         # batch decoding
         h, _, states = self.encoder.forward_one_step(
-            emb, self._target_mask(ys), cache=batch_state
-        )
+            emb, self._target_mask(ys), cache=batch_state)
         h = self.decoder(h[:, -1])
         logp = F.log_softmax(h)
 
         # transpose state of [layer, batch] into [batch, layer]
-        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+        state_list = [[states[i][b] for i in range(n_layers)]
+                      for b in range(n_batch)]
         return logp, state_list
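The `state_list` comprehension in the last hunk only flips the nesting of the cache: the encoder returns one entry per layer spanning the whole batch, while beam search wants one list per hypothesis. A pure-Python toy (all names invented) to make the index swap concrete:

```python
n_layers, n_batch = 2, 3
# Encoder layout: states[layer][hypothesis]
states = [[f"L{i}/hyp{b}" for b in range(n_batch)] for i in range(n_layers)]

# Beam-search layout: state_list[hypothesis][layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]

assert state_list[1] == ["L0/hyp1", "L1/hyp1"]  # state_list[b][i] == states[i][b]
```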