replace list(zip()) with paddle.stack

pull/651/head
Hui Zhang 3 years ago
parent 34689bd1df
commit 4acaaba349

@@ -24,7 +24,7 @@ __all__ = ['CRF']
class CRF(nn.Layer):
"""
Linear-chain Conditional Random Field (CRF).
Args:
nb_labels (int): number of labels in your tagset, including special symbols.
bos_tag_id (int): integer representing the beginning of sentence symbol in
@@ -162,15 +162,15 @@ class CRF(nn.Layer):
# save first and last tags to be used later
first_tags = tags[:, 0]
last_valid_idx = mask.int().sum(1) - 1
# TODO(Hui Zhang): not support fancy index.
# last_tags = tags.gather(last_valid_idx.unsqueeze(1), axis=1).squeeze()
-batch_idx = paddle.arange(batch_size)
-gather_last_valid_idx = paddle.to_tensor(
-    list(zip(batch_idx.tolist(), last_valid_idx.tolist())))
+batch_idx = paddle.arange(batch_size, dtype=last_valid_idx.dtype)
+gather_last_valid_idx = paddle.stack(
+    [batch_idx, last_valid_idx], axis=-1)
last_tags = tags.gather_nd(gather_last_valid_idx)
# add the transition from BOS to the first tags for each batch
# TODO(Hui Zhang): not support fancy index.
# t_scores = self.transitions[self.BOS_TAG_ID, first_tags]
t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags)
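The same substitution repeats in the hunks below; this first one carries the whole idea. The standalone sketch that follows is not part of the diff (toy tensors, made-up values): both the old list(zip(...)) construction and the new paddle.stack build a [batch, 2] integer index that gather_nd uses in place of the unsupported fancy indexing tags[batch_idx, last_valid_idx], but stacking stays on-device instead of round-tripping through Python lists via .tolist(). The extra dtype=last_valid_idx.dtype on paddle.arange presumably just ensures the two index vectors share a dtype before they are stacked.

    import paddle

    # Toy stand-ins for tags / batch_idx / last_valid_idx from the hunk above.
    tags = paddle.to_tensor([[3, 1, 0], [2, 2, 4]])            # (batch=2, seq_len=3)
    batch_idx = paddle.to_tensor([0, 1], dtype='int64')
    last_valid_idx = paddle.to_tensor([2, 1], dtype='int64')   # last non-padded step per row

    # old: index pairs built on the host through Python lists
    idx_old = paddle.to_tensor(
        list(zip(batch_idx.tolist(), last_valid_idx.tolist())))

    # new: stack the two index vectors into shape [batch, 2] without leaving the device
    idx_new = paddle.stack([batch_idx, last_valid_idx], axis=-1)

    # either index picks tags[batch_idx, last_valid_idx], i.e. the last tag of each sequence
    last_tags = tags.gather_nd(idx_new)                        # -> [0, 2]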
@@ -178,10 +178,8 @@ class CRF(nn.Layer):
# for all batches, the first word, see the correspondent emissions
# for the first tags (which is a list of ids):
# emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]]
# TODO(Hui Zhang): not support fancy index.
# e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze()
-gather_first_tags_idx = paddle.to_tensor(
-    list(zip(batch_idx.tolist(), first_tags.tolist())))
+gather_first_tags_idx = paddle.stack([batch_idx, first_tags], axis=-1)
e_scores = emissions[:, 0].gather_nd(gather_first_tags_idx)
# the scores for a word is just the sum of both scores
@@ -199,15 +197,13 @@ class CRF(nn.Layer):
current_tags = tags[:, i]
# calculate emission and transition scores as we did before
# TODO(Hui Zhang): not support fancy index.
# e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze()
-gather_current_tags_idx = paddle.to_tensor(
-    list(zip(batch_idx.tolist(), current_tags.tolist())))
+gather_current_tags_idx = paddle.stack(
+    [batch_idx, current_tags], axis=-1)
e_scores = emissions[:, i].gather_nd(gather_current_tags_idx)
# TODO(Hui Zhang): not support fancy index.
# t_scores = self.transitions[previous_tags, current_tags]
-gather_transitions_idx = paddle.to_tensor(
-    list(zip(previous_tags.tolist(), current_tags.tolist())))
+gather_transitions_idx = paddle.stack(
+    [previous_tags, current_tags], axis=-1)
t_scores = self.transitions.gather_nd(gather_transitions_idx)
# apply the mask
@@ -300,7 +296,6 @@ class CRF(nn.Layer):
# so far is exactly like the forward algorithm,
# but now, instead of calculating the logsumexp,
# we will find the highest score and the tag associated with it
# TODO(Hui Zhang): max not support return score and index.
# max_scores, max_score_tags = paddle.max(scores, axis=1)
max_scores = paddle.max(scores, axis=1)
max_score_tags = paddle.argmax(scores, axis=1)
@@ -319,7 +314,6 @@ class CRF(nn.Layer):
end_scores = alphas + last_transition.unsqueeze(0)
# get the final most probable score and the final most probable tag
# TODO(Hui Zhang): max not support return score and index.
# max_final_scores, max_final_tags = paddle.max(end_scores, axis=1)
max_final_scores = paddle.max(end_scores, axis=1)
max_final_tags = paddle.argmax(end_scores, axis=1)
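The final two hunks each shrink by a single line; the max/argmax workaround shown in their context is the interesting part. paddle.max returns only the maximum values, unlike torch.max, so the Viterbi step pairs paddle.max with paddle.argmax to also recover which tag the maximum belongs to. A minimal standalone sketch with toy scores, not taken from the diff:

    import paddle

    scores = paddle.to_tensor([[0.1, 2.0, 0.3],
                               [1.5, 0.2, 0.9]])

    # paddle.max yields the best score per row; argmax yields the tag id it came from.
    max_scores = paddle.max(scores, axis=1)         # -> [2.0, 1.5]
    max_score_tags = paddle.argmax(scores, axis=1)  # -> [1, 0]
    # (the commented-out one-liner in the source would need a torch-style
    #  max that returns values and indices together)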
