|
|
|
@ -162,15 +162,15 @@ class CRF(nn.Layer):
|
|
|
|
|
# save first and last tags to be used later
|
|
|
|
|
first_tags = tags[:, 0]
|
|
|
|
|
last_valid_idx = mask.int().sum(1) - 1
|
|
|
|
|
|
|
|
|
|
# TODO(Hui Zhang): not support fancy index.
|
|
|
|
|
# last_tags = tags.gather(last_valid_idx.unsqueeze(1), axis=1).squeeze()
|
|
|
|
|
batch_idx = paddle.arange(batch_size)
|
|
|
|
|
gather_last_valid_idx = paddle.to_tensor(
|
|
|
|
|
list(zip(batch_idx.tolist(), last_valid_idx.tolist())))
|
|
|
|
|
batch_idx = paddle.arange(batch_size, dtype=last_valid_idx.dtype)
|
|
|
|
|
gather_last_valid_idx = paddle.stack(
|
|
|
|
|
[batch_idx, last_valid_idx], axis=-1)
|
|
|
|
|
last_tags = tags.gather_nd(gather_last_valid_idx)
|
|
|
|
|
|
|
|
|
|
# add the transition from BOS to the first tags for each batch
|
|
|
|
|
# TODO(Hui Zhang): not support fancy index.
|
|
|
|
|
# t_scores = self.transitions[self.BOS_TAG_ID, first_tags]
|
|
|
|
|
t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags)
|
|
|
|
|
|
|
|
|
@ -178,10 +178,8 @@ class CRF(nn.Layer):
|
|
|
|
|
# for all batches, the first word, see the correspondent emissions
|
|
|
|
|
# for the first tags (which is a list of ids):
|
|
|
|
|
# emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]]
|
|
|
|
|
# TODO(Hui Zhang): not support fancy index.
|
|
|
|
|
# e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze()
|
|
|
|
|
gather_first_tags_idx = paddle.to_tensor(
|
|
|
|
|
list(zip(batch_idx.tolist(), first_tags.tolist())))
|
|
|
|
|
gather_first_tags_idx = paddle.stack([batch_idx, first_tags], axis=-1)
|
|
|
|
|
e_scores = emissions[:, 0].gather_nd(gather_first_tags_idx)
|
|
|
|
|
|
|
|
|
|
# the scores for a word is just the sum of both scores
|
|
|
|
@ -199,15 +197,13 @@ class CRF(nn.Layer):
|
|
|
|
|
current_tags = tags[:, i]
|
|
|
|
|
|
|
|
|
|
# calculate emission and transition scores as we did before
|
|
|
|
|
# TODO(Hui Zhang): not support fancy index.
|
|
|
|
|
# e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze()
|
|
|
|
|
gather_current_tags_idx = paddle.to_tensor(
|
|
|
|
|
list(zip(batch_idx.tolist(), current_tags.tolist())))
|
|
|
|
|
gather_current_tags_idx = paddle.stack(
|
|
|
|
|
[batch_idx, current_tags], axis=-1)
|
|
|
|
|
e_scores = emissions[:, i].gather_nd(gather_current_tags_idx)
|
|
|
|
|
# TODO(Hui Zhang): not support fancy index.
|
|
|
|
|
# t_scores = self.transitions[previous_tags, current_tags]
|
|
|
|
|
gather_transitions_idx = paddle.to_tensor(
|
|
|
|
|
list(zip(previous_tags.tolist(), current_tags.tolist())))
|
|
|
|
|
gather_transitions_idx = paddle.stack(
|
|
|
|
|
[previous_tags, current_tags], axis=-1)
|
|
|
|
|
t_scores = self.transitions.gather_nd(gather_transitions_idx)
|
|
|
|
|
|
|
|
|
|
# apply the mask
|
|
|
|
@ -300,7 +296,6 @@ class CRF(nn.Layer):
|
|
|
|
|
# so far is exactly like the forward algorithm,
|
|
|
|
|
# but now, instead of calculating the logsumexp,
|
|
|
|
|
# we will find the highest score and the tag associated with it
|
|
|
|
|
# TODO(Hui Zhang): max not support return score and index.
|
|
|
|
|
# max_scores, max_score_tags = paddle.max(scores, axis=1)
|
|
|
|
|
max_scores = paddle.max(scores, axis=1)
|
|
|
|
|
max_score_tags = paddle.argmax(scores, axis=1)
|
|
|
|
@ -319,7 +314,6 @@ class CRF(nn.Layer):
|
|
|
|
|
end_scores = alphas + last_transition.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
# get the final most probable score and the final most probable tag
|
|
|
|
|
# TODO(Hui Zhang): max not support return score and index.
|
|
|
|
|
# max_final_scores, max_final_tags = paddle.max(end_scores, axis=1)
|
|
|
|
|
max_final_scores = paddle.max(end_scores, axis=1)
|
|
|
|
|
max_final_tags = paddle.argmax(end_scores, axis=1)
|
|
|
|
|