You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
371 lines
15 KiB
371 lines
15 KiB
4 years ago
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
import paddle
|
||
|
from paddle import nn
|
||
|
|
||
3 years ago
|
from paddlespeech.s2t.utils.log import Log
|
||
4 years ago
|
|
||
|
# Module-level logger, obtained from the project's `Log` helper and named
# after this module.
logger = Log(__name__).getlog()
|
||
|
|
||
|
# Public API of this module: only the CRF layer is exported.
__all__ = ['CRF']
|
||
|
|
||
|
|
||
|
class CRF(nn.Layer):
    """Linear-chain Conditional Random Field (CRF).

    The layer owns a single learnable ``(nb_labels, nb_labels)`` matrix of
    transition scores, indexed as ``transitions[from_tag, to_tag]``.
    Emission scores are supplied by the caller.

    Notes:
        * The first time step of every sequence is always scored; the mask is
          only consulted from the second time step on.
        * Sequence lengths are derived as ``mask.sum(axis=1)``, so the valid
          positions of each sequence are expected to form a prefix
          (i.e. right padding).

    Args:
        nb_labels (int): number of labels in your tagset, including special
            symbols.
        bos_tag_id (int): integer representing the beginning of sentence
            symbol in your tagset.
        eos_tag_id (int): integer representing the end of sentence symbol in
            your tagset.
        pad_tag_id (int, optional): integer representing the pad symbol in
            your tagset. If None, the model will treat the PAD as a normal
            tag. Otherwise, the model will apply constraints for PAD
            transitions.
        batch_first (bool): Whether the first dimension represents the batch
            dimension.
    """

    def __init__(self,
                 nb_labels: int,
                 bos_tag_id: int,
                 eos_tag_id: int,
                 pad_tag_id: int=None,
                 batch_first: bool=True):
        super().__init__()

        self.nb_labels = nb_labels
        self.BOS_TAG_ID = bos_tag_id
        self.EOS_TAG_ID = eos_tag_id
        self.PAD_TAG_ID = pad_tag_id
        self.batch_first = batch_first

        # initialize transitions from a random uniform distribution
        # between -0.1 and 0.1
        self.transitions = self.create_parameter(
            [self.nb_labels, self.nb_labels],
            default_initializer=nn.initializer.Uniform(-0.1, 0.1))
        self.init_weights()

    def init_weights(self):
        """Write the structural constraints into the transition matrix.

        Forbidden transitions (rows=from, columns=to) are set to a big
        negative number, so that exp(-10000) tends to zero.
        """
        # This is a one-off initialisation of parameter values, not part of
        # the differentiable computation, so it is done without autograd
        # tracking the in-place writes.
        with paddle.no_grad():
            # no transitions allowed to the beginning of sentence
            self.transitions[:, self.BOS_TAG_ID] = -10000.0
            # no transitions allowed from the end of sentence
            self.transitions[self.EOS_TAG_ID, :] = -10000.0

            if self.PAD_TAG_ID is not None:
                # no transitions from padding
                self.transitions[self.PAD_TAG_ID, :] = -10000.0
                # no transitions to padding
                self.transitions[:, self.PAD_TAG_ID] = -10000.0
                # except PAD -> EOS and PAD -> PAD, which stay free (score 0).
                # These two writes must come after the two above.
                self.transitions[self.PAD_TAG_ID, self.EOS_TAG_ID] = 0.0
                self.transitions[self.PAD_TAG_ID, self.PAD_TAG_ID] = 0.0

    def forward(self,
                emissions: paddle.Tensor,
                tags: paddle.Tensor,
                mask: paddle.Tensor=None) -> paddle.Tensor:
        """Compute the negative log-likelihood. See `log_likelihood` method."""
        nll = -self.log_likelihood(emissions, tags, mask=mask)
        return nll

    def _normalize_inputs(self, emissions, mask, tags=None):
        """Bring the inputs to batch-first layout and build/cast the mask.

        Args:
            emissions (paddle.Tensor): emissions in the layout selected by
                ``self.batch_first``.
            mask (paddle.Tensor or None): mask in the same layout, or None.
            tags (paddle.Tensor, optional): tags in the same layout, or None
                when decoding.

        Returns:
            tuple: ``(emissions, mask, tags)`` where every tensor has the
            batch as its first dimension and ``mask`` has the dtype of
            ``emissions`` (all ones if no mask was given).
        """
        if not self.batch_first:
            # paddle.transpose takes the full permutation, not a pair of axes
            emissions = paddle.transpose(emissions, perm=[1, 0, 2])
            if tags is not None:
                tags = paddle.transpose(tags, perm=[1, 0])
            if mask is not None:
                mask = paddle.transpose(mask, perm=[1, 0])

        if mask is None:
            mask = paddle.ones(emissions.shape[:2], dtype=emissions.dtype)
        else:
            # the mask is multiplied with the scores, so the dtypes must agree
            mask = paddle.cast(mask, emissions.dtype)
        return emissions, mask, tags

    def log_likelihood(self, emissions, tags, mask=None):
        """Compute the log-likelihood of a sequence of tags given a sequence
        of emissions scores.

        Args:
            emissions (paddle.Tensor): Sequence of emissions for each label.
                Shape of (batch_size, seq_len, nb_labels) if batch_first is
                True, (seq_len, batch_size, nb_labels) otherwise.
            tags (paddle.Tensor): Sequence of integer labels.
                Shape of (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.
            mask (paddle.Tensor, optional): Tensor representing valid
                positions (non-zero means valid). If None, all positions are
                considered valid.
                Shape of (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.

        Returns:
            paddle.Tensor: sum of the log-likelihoods for each sequence in
                the batch.
        """
        emissions, mask, tags = self._normalize_inputs(emissions, mask, tags)

        scores = self._compute_scores(emissions, tags, mask=mask)
        partition = self._compute_log_partition(emissions, mask=mask)
        return paddle.sum(scores - partition)

    def decode(self, emissions, mask=None):
        """Find the most probable sequence of labels given the emissions using
        the Viterbi algorithm.

        Args:
            emissions (paddle.Tensor): Sequence of emissions for each label.
                Shape (batch_size, seq_len, nb_labels) if batch_first is True,
                (seq_len, batch_size, nb_labels) otherwise.
            mask (paddle.Tensor, optional): Tensor representing valid
                positions (non-zero means valid). If None, all positions are
                considered valid.
                Shape (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.

        Returns:
            paddle.Tensor: the viterbi score for each sequence in the batch.
                Shape of (batch_size,)
            list of lists: the best viterbi sequence of labels for each
                sequence in the batch; each inner list is as long as the
                number of valid positions of that sequence.
        """
        # decoding has no tags, only emissions and mask need normalizing
        emissions, mask, _ = self._normalize_inputs(emissions, mask)

        scores, sequences = self._viterbi_decode(emissions, mask)
        return scores, sequences

    def _compute_scores(self, emissions, tags, mask):
        """Compute the scores for a given batch of emissions with their tags.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            tags (paddle.Tensor): integer labels, (batch_size, seq_len)
            mask (paddle.Tensor): float mask, (batch_size, seq_len)

        Returns:
            paddle.Tensor: Scores for each batch.
                Shape of (batch_size,)
        """
        batch_size, seq_length = tags.shape

        # every index tensor that is stacked for gather_nd must share a dtype
        tags = paddle.cast(tags, 'int64')
        batch_idx = paddle.arange(batch_size, dtype='int64')

        # first tag and last valid tag of every sequence
        first_tags = tags[:, 0]
        last_valid_idx = paddle.cast(mask, 'int64').sum(axis=1) - 1
        last_tags = tags.gather_nd(
            paddle.stack([batch_idx, last_valid_idx], axis=-1))

        # transition from BOS to the first tag of each sequence
        t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags)

        # emission score of the first tag of each sequence, i.e.
        # emissions[b, 0, first_tags[b]] for every b
        e_scores = emissions[:, 0].gather_nd(
            paddle.stack([batch_idx, first_tags], axis=-1))

        # the score for a word is just the sum of both scores
        scores = e_scores + t_scores

        # now do the same for each remaining word
        for i in range(1, seq_length):
            # instead of stopping per sequence at the first masked position,
            # process the whole batch and zero out the masked contributions
            is_valid = mask[:, i]

            previous_tags = tags[:, i - 1]
            current_tags = tags[:, i]

            # emissions[b, i, current_tags[b]]
            e_scores = emissions[:, i].gather_nd(
                paddle.stack([batch_idx, current_tags], axis=-1))
            # transitions[previous_tags[b], current_tags[b]]
            t_scores = self.transitions.gather_nd(
                paddle.stack([previous_tags, current_tags], axis=-1))

            scores = scores + (e_scores + t_scores) * is_valid

        # add the transition from the last valid tag to EOS for each sequence
        scores = scores + self.transitions[:, self.EOS_TAG_ID].gather(
            last_tags)

        return scores

    def _compute_log_partition(self, emissions, mask):
        """Compute the partition function in log-space using the
        forward-algorithm.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            mask (paddle.Tensor): float mask, (batch_size, seq_len)

        Returns:
            paddle.Tensor: the partition scores for each batch.
                Shape of (batch_size,)
        """
        seq_length = emissions.shape[1]

        # in the first iteration, BOS will have all the scores
        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
            0) + emissions[:, 0]

        # (nb_labels, nb_labels) -> (1, nb_labels, nb_labels); loop invariant
        t_scores = self.transitions.unsqueeze(0)

        for i in range(1, seq_length):
            # (bs, nb_labels) -> (bs, 1, nb_labels)
            e_scores = emissions[:, i].unsqueeze(1)
            # (bs, nb_labels) -> (bs, nb_labels, 1)
            a_scores = alphas.unsqueeze(2)

            # scores[b, from, to]; reduce over the "from" axis
            scores = e_scores + t_scores + a_scores
            new_alphas = paddle.logsumexp(scores, axis=1)

            # set alphas if the mask is valid, otherwise keep the current values
            is_valid = mask[:, i].unsqueeze(-1)
            alphas = is_valid * new_alphas + (1 - is_valid) * alphas

        # add the scores for the final transition
        last_transition = self.transitions[:, self.EOS_TAG_ID]
        end_scores = alphas + last_transition.unsqueeze(0)

        # return a *log* of sums of exps
        return paddle.logsumexp(end_scores, axis=1)

    def _viterbi_decode(self, emissions, mask):
        """Compute the viterbi algorithm to find the most probable sequence of
        labels given a sequence of emissions.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            mask (paddle.Tensor): float mask, (batch_size, seq_len)

        Returns:
            paddle.Tensor: the viterbi score for each sequence in the batch.
                Shape of (batch_size,)
            list of lists of ints: the best viterbi sequence of labels for
                each sequence in the batch
        """
        batch_size, seq_length = emissions.shape[:2]

        # in the first iteration, BOS will have all the scores and then, the max
        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
            0) + emissions[:, 0]

        # (nb_labels, nb_labels) -> (1, nb_labels, nb_labels); loop invariant
        t_scores = self.transitions.unsqueeze(0)

        backpointers = []

        for i in range(1, seq_length):
            # (bs, nb_labels) -> (bs, 1, nb_labels)
            e_scores = emissions[:, i].unsqueeze(1)
            # (bs, nb_labels) -> (bs, nb_labels, 1)
            a_scores = alphas.unsqueeze(2)

            # combine current scores with previous alphas
            scores = e_scores + t_scores + a_scores

            # so far this is exactly like the forward algorithm, but instead
            # of the logsumexp we keep the highest score and the previous tag
            # that produced it
            max_scores = paddle.max(scores, axis=1)
            max_score_tags = paddle.argmax(scores, axis=1)

            # set alphas if the mask is valid, otherwise keep the current values
            is_valid = mask[:, i].unsqueeze(-1)
            alphas = is_valid * max_scores + (1 - is_valid) * alphas

            # max_score_tags has shape (batch_size, nb_labels); store it as
            # (nb_labels, batch_size), the layout `_find_best_path` indexes
            backpointers.append(max_score_tags.t())

        # add the scores for the final transition
        last_transition = self.transitions[:, self.EOS_TAG_ID]
        end_scores = alphas + last_transition.unsqueeze(0)

        # get the final most probable score and the final most probable tag
        max_final_scores = paddle.max(end_scores, axis=1)
        max_final_tags = paddle.argmax(end_scores, axis=1)

        # Back-tracking is plain Python, so copy everything it needs to the
        # host once instead of reading single elements from device tensors.
        backpointers = [pointers.numpy() for pointers in backpointers]
        final_tags = max_final_tags.numpy().tolist()
        emission_lengths = paddle.cast(mask,
                                       'int64').sum(axis=1).numpy().tolist()

        # find the best sequence of labels for each sample in the batch
        best_sequences = []
        for i in range(batch_size):
            # Only the first (length - 1) backpointers belong to this sample,
            # since the last position corresponds to the final tag itself.
            # Clamp at 0 so an all-masked row does not turn into a [:-1] slice.
            nb_steps = max(emission_lengths[i] - 1, 0)

            # follow the backpointers to build the sequence of labels
            sample_path = self._find_best_path(i, final_tags[i],
                                               backpointers[:nb_steps])
            best_sequences.append(sample_path)

        return max_final_scores, best_sequences

    def _find_best_path(self, sample_id, best_tag, backpointers):
        """Auxiliary function to find the best path sequence for a specific
        sample.

        Args:
            sample_id (int): sample index in the range [0, batch_size)
            best_tag (int): tag which maximizes the final score
            backpointers (list): one entry per time step, each indexable as
                ``entry[tag][sample_id]`` (shape (nb_labels, batch_size)),
                giving the best previous tag. The list has seq_len_i - 1
                entries, where seq_len_i is the length of the sample.

        Returns:
            list of ints: a list of tag indexes representing the best path
        """
        # start from the final best_tag and walk backwards in time
        best_path = [best_tag]

        for backpointers_t in reversed(backpointers):
            # recover the best_tag at this timestep
            best_tag = int(backpointers_t[best_tag][sample_id])
            best_path.append(best_tag)

        # the path was collected from last to first
        best_path.reverse()
        return best_path
|