diff --git a/deepspeech/modules/crf.py b/deepspeech/modules/crf.py index 4bdc5a90..b6b481a0 100644 --- a/deepspeech/modules/crf.py +++ b/deepspeech/modules/crf.py @@ -24,7 +24,7 @@ __all__ = ['CRF'] class CRF(nn.Layer): """ Linear-chain Conditional Random Field (CRF). - + Args: nb_labels (int): number of labels in your tagset, including special symbols. bos_tag_id (int): integer representing the beginning of sentence symbol in @@ -162,15 +162,15 @@ class CRF(nn.Layer): # save first and last tags to be used later first_tags = tags[:, 0] last_valid_idx = mask.int().sum(1) - 1 + # TODO(Hui Zhang): not support fancy index. # last_tags = tags.gather(last_valid_idx.unsqueeze(1), axis=1).squeeze() - batch_idx = paddle.arange(batch_size) - gather_last_valid_idx = paddle.to_tensor( - list(zip(batch_idx.tolist(), last_valid_idx.tolist()))) + batch_idx = paddle.arange(batch_size, dtype=last_valid_idx.dtype) + gather_last_valid_idx = paddle.stack( + [batch_idx, last_valid_idx], axis=-1) last_tags = tags.gather_nd(gather_last_valid_idx) # add the transition from BOS to the first tags for each batch - # TODO(Hui Zhang): not support fancy index. # t_scores = self.transitions[self.BOS_TAG_ID, first_tags] t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags) @@ -178,10 +178,8 @@ class CRF(nn.Layer): # for all batches, the first word, see the correspondent emissions # for the first tags (which is a list of ids): # emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]] - # TODO(Hui Zhang): not support fancy index. # e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze() - gather_first_tags_idx = paddle.to_tensor( - list(zip(batch_idx.tolist(), first_tags.tolist()))) + gather_first_tags_idx = paddle.stack([batch_idx, first_tags], axis=-1) e_scores = emissions[:, 0].gather_nd(gather_first_tags_idx) # the scores for a word is just the sum of both scores @@ -199,15 +197,13 @@ class CRF(nn.Layer): current_tags = tags[:, i] # calculate emission and transition scores as we did before - # TODO(Hui Zhang): not support fancy index. # e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze() - gather_current_tags_idx = paddle.to_tensor( - list(zip(batch_idx.tolist(), current_tags.tolist()))) + gather_current_tags_idx = paddle.stack( + [batch_idx, current_tags], axis=-1) e_scores = emissions[:, i].gather_nd(gather_current_tags_idx) - # TODO(Hui Zhang): not support fancy index. # t_scores = self.transitions[previous_tags, current_tags] - gather_transitions_idx = paddle.to_tensor( - list(zip(previous_tags.tolist(), current_tags.tolist()))) + gather_transitions_idx = paddle.stack( + [previous_tags, current_tags], axis=-1) t_scores = self.transitions.gather_nd(gather_transitions_idx) # apply the mask @@ -300,7 +296,6 @@ class CRF(nn.Layer): # so far is exactly like the forward algorithm, # but now, instead of calculating the logsumexp, # we will find the highest score and the tag associated with it - # TODO(Hui Zhang): max not support return score and index. # max_scores, max_score_tags = paddle.max(scores, axis=1) max_scores = paddle.max(scores, axis=1) max_score_tags = paddle.argmax(scores, axis=1) @@ -319,7 +314,6 @@ class CRF(nn.Layer): end_scores = alphas + last_transition.unsqueeze(0) # get the final most probable score and the final most probable tag - # TODO(Hui Zhang): max not support return score and index. # max_final_scores, max_final_tags = paddle.max(end_scores, axis=1) max_final_scores = paddle.max(end_scores, axis=1) max_final_tags = paddle.argmax(end_scores, axis=1)