From 34689bd1df3f3d0ba0db414f2bb905a68de2e64d Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Tue, 1 Jun 2021 08:25:55 +0000
Subject: [PATCH] add crf

---
 deepspeech/__init__.py                      | 9 +
 deepspeech/modules/crf.py                   | 376 ++++++++++++++++++++
 examples/chinese_g2p/local/ignore_sandhi.py | 7 +-
 examples/librispeech/s0/README.md           | 2 +-
 4 files changed, 391 insertions(+), 3 deletions(-)
 create mode 100644 deepspeech/modules/crf.py

diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py
index c942de0c..37531657 100644
--- a/deepspeech/__init__.py
+++ b/deepspeech/__init__.py
@@ -345,6 +345,15 @@ if not hasattr(paddle.Tensor, 'float'):
     setattr(paddle.Tensor, 'float', func_float)
 
 
+def func_int(x: paddle.Tensor) -> paddle.Tensor:
+    return x.astype(paddle.int)
+
+
+if not hasattr(paddle.Tensor, 'int'):
+    logger.warn("register user int to paddle.Tensor, remove this when fixed!")
+    setattr(paddle.Tensor, 'int', func_int)
+
+
 def tolist(x: paddle.Tensor) -> List[Any]:
     return x.numpy().tolist()
 
diff --git a/deepspeech/modules/crf.py b/deepspeech/modules/crf.py
new file mode 100644
index 00000000..4bdc5a90
--- /dev/null
+++ b/deepspeech/modules/crf.py
@@ -0,0 +1,376 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+from paddle import nn
+
+from deepspeech.utils.log import Log
+
+logger = Log(__name__).getlog()
+
+__all__ = ['CRF']
+
+
+class CRF(nn.Layer):
+    """
+    Linear-chain Conditional Random Field (CRF).
+
+    Args:
+        nb_labels (int): number of labels in your tagset, including special symbols.
+        bos_tag_id (int): integer representing the beginning of sentence symbol in
+            your tagset.
+        eos_tag_id (int): integer representing the end of sentence symbol in your tagset.
+        pad_tag_id (int, optional): integer representing the pad symbol in your tagset.
+            If None, the model will treat the PAD as a normal tag. Otherwise, the model
+            will apply constraints for PAD transitions.
+        batch_first (bool): Whether the first dimension represents the batch dimension.
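+
+    Example (an illustrative sketch, not part of the class contract; the tag
+    ids, shapes, and random inputs below are made up):
+
+        >>> crf = CRF(nb_labels=5, bos_tag_id=3, eos_tag_id=4)
+        >>> emissions = paddle.randn([2, 7, 5])  # (batch, seq_len, nb_labels)
+        >>> tags = paddle.randint(0, 3, [2, 7])  # gold tag ids per timestep
+        >>> nll = crf(emissions, tags)  # negative log-likelihood, shape ()
+        >>> scores, paths = crf.decode(emissions)  # viterbi scores and paths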
+    """
+
+    def __init__(self,
+                 nb_labels: int,
+                 bos_tag_id: int,
+                 eos_tag_id: int,
+                 pad_tag_id: int=None,
+                 batch_first: bool=True):
+        super().__init__()
+
+        self.nb_labels = nb_labels
+        self.BOS_TAG_ID = bos_tag_id
+        self.EOS_TAG_ID = eos_tag_id
+        self.PAD_TAG_ID = pad_tag_id
+        self.batch_first = batch_first
+
+        # initialize transitions from a random uniform distribution between -0.1 and 0.1
+        self.transitions = self.create_parameter(
+            [self.nb_labels, self.nb_labels],
+            default_initializer=nn.initializer.Uniform(-0.1, 0.1))
+        self.init_weights()
+
+    def init_weights(self):
+        # enforce constraints (rows=from, columns=to) with a big negative number
+        # so exp(-10000) will tend to zero
+
+        # no transitions allowed to the beginning of sentence
+        self.transitions[:, self.BOS_TAG_ID] = -10000.0
+        # no transitions allowed from the end of sentence
+        self.transitions[self.EOS_TAG_ID, :] = -10000.0
+
+        if self.PAD_TAG_ID is not None:
+            # no transitions from padding
+            self.transitions[self.PAD_TAG_ID, :] = -10000.0
+            # no transitions to padding
+            self.transitions[:, self.PAD_TAG_ID] = -10000.0
+            # except if the end of sentence is reached
+            # or we are already in a pad position
+            self.transitions[self.PAD_TAG_ID, self.EOS_TAG_ID] = 0.0
+            self.transitions[self.PAD_TAG_ID, self.PAD_TAG_ID] = 0.0
+
+    def forward(self,
+                emissions: paddle.Tensor,
+                tags: paddle.Tensor,
+                mask: paddle.Tensor=None) -> paddle.Tensor:
+        """Compute the negative log-likelihood. See `log_likelihood` method."""
+        nll = -self.log_likelihood(emissions, tags, mask=mask)
+        return nll
+
+    def log_likelihood(self, emissions, tags, mask=None):
+        """Compute the log-likelihood of a sequence of tags given a sequence
+        of emission scores.
+
+        Args:
+            emissions (paddle.Tensor): Sequence of emissions for each label.
+                Shape of (batch_size, seq_len, nb_labels) if batch_first is True,
+                (seq_len, batch_size, nb_labels) otherwise.
+            tags (paddle.LongTensor): Sequence of labels.
+                Shape of (batch_size, seq_len) if batch_first is True,
+                (seq_len, batch_size) otherwise.
+            mask (paddle.FloatTensor, optional): Tensor representing valid positions.
+                If None, all positions are considered valid.
+                Shape of (batch_size, seq_len) if batch_first is True,
+                (seq_len, batch_size) otherwise.
+
+        Returns:
+            paddle.Tensor: sum of the log-likelihoods for each sequence in the batch.
+                Shape of ()
+        """
+        # fix tensor order by setting batch as the first dimension
+        if not self.batch_first:
+            emissions = emissions.transpose([1, 0, 2])
+            tags = tags.transpose([1, 0])
+
+        if mask is None:
+            mask = paddle.ones(emissions.shape[:2], dtype=paddle.float)
+
+        scores = self._compute_scores(emissions, tags, mask=mask)
+        partition = self._compute_log_partition(emissions, mask=mask)
+        return paddle.sum(scores - partition)
+
+    def decode(self, emissions, mask=None):
+        """Find the most probable sequence of labels given the emissions using
+        the Viterbi algorithm.
+
+        Args:
+            emissions (paddle.Tensor): Sequence of emissions for each label.
+                Shape (batch_size, seq_len, nb_labels) if batch_first is True,
+                (seq_len, batch_size, nb_labels) otherwise.
+            mask (paddle.FloatTensor, optional): Tensor representing valid positions.
+                If None, all positions are considered valid.
+                Shape (batch_size, seq_len) if batch_first is True,
+                (seq_len, batch_size) otherwise.
+
+        Returns:
+            paddle.Tensor: the viterbi score for each batch.
+                Shape of (batch_size,)
+            list of lists: the best viterbi sequence of labels for each batch.
+                [B, T]
+        """
+        # fix tensor order by setting batch as the first dimension
+        if not self.batch_first:
+            emissions = emissions.transpose([1, 0, 2])
+
+        if mask is None:
+            mask = paddle.ones(emissions.shape[:2], dtype=paddle.float)
+
+        scores, sequences = self._viterbi_decode(emissions, mask)
+        return scores, sequences
+
+    def _compute_scores(self, emissions, tags, mask):
+        """Compute the scores for a given batch of emissions with their tags.
+
+        Args:
+            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
+            tags (paddle.LongTensor): (batch_size, seq_len)
+            mask (paddle.FloatTensor): (batch_size, seq_len)
+
+        Returns:
+            paddle.Tensor: Scores for each batch.
+                Shape of (batch_size,)
+        """
+        batch_size, seq_length = tags.shape
+        scores = paddle.zeros([batch_size])
+
+        # save first and last tags to be used later
+        first_tags = tags[:, 0]
+        last_valid_idx = mask.int().sum(1) - 1
+        # TODO(Hui Zhang): fancy indexing is not supported yet.
+        # last_tags = tags.gather(last_valid_idx.unsqueeze(1), axis=1).squeeze()
+        batch_idx = paddle.arange(batch_size)
+        gather_last_valid_idx = paddle.to_tensor(
+            list(zip(batch_idx.tolist(), last_valid_idx.tolist())))
+        last_tags = tags.gather_nd(gather_last_valid_idx)
+
+        # add the transition from BOS to the first tags for each batch
+        # TODO(Hui Zhang): fancy indexing is not supported yet.
+        # t_scores = self.transitions[self.BOS_TAG_ID, first_tags]
+        t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags)
+
+        # add the [unary] emission scores for the first tags for each batch:
+        # for every sequence in the batch, look up the emission score of the
+        # first word at its first tag (first_tags is a list of tag ids):
+        # emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]]
+        # TODO(Hui Zhang): fancy indexing is not supported yet.
+        # e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze()
+        gather_first_tags_idx = paddle.to_tensor(
+            list(zip(batch_idx.tolist(), first_tags.tolist())))
+        e_scores = emissions[:, 0].gather_nd(gather_first_tags_idx)
+
+        # the score for a word is just the sum of both scores
+        scores += e_scores + t_scores
+
+        # now let's do this for each remaining word
+        for i in range(1, seq_length):
+
+            # we could iterate over batches, check if we reached a mask symbol
+            # and stop the iteration, but vectorizing is faster on GPU,
+            # so instead we perform an element-wise multiplication
+            is_valid = mask[:, i]
+
+            previous_tags = tags[:, i - 1]
+            current_tags = tags[:, i]
+
+            # calculate emission and transition scores as we did before
+            # TODO(Hui Zhang): fancy indexing is not supported yet.
+            # e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze()
+            gather_current_tags_idx = paddle.to_tensor(
+                list(zip(batch_idx.tolist(), current_tags.tolist())))
+            e_scores = emissions[:, i].gather_nd(gather_current_tags_idx)
+            # TODO(Hui Zhang): fancy indexing is not supported yet.
+            # t_scores = self.transitions[previous_tags, current_tags]
+            gather_transitions_idx = paddle.to_tensor(
+                list(zip(previous_tags.tolist(), current_tags.tolist())))
+            t_scores = self.transitions.gather_nd(gather_transitions_idx)
+
+            # apply the mask
+            e_scores = e_scores * is_valid
+            t_scores = t_scores * is_valid
+
+            scores += e_scores + t_scores
+
+        # add the transition from the last tag to the EOS tag for each batch
+        # scores += self.transitions[last_tags, self.EOS_TAG_ID]
+        scores += self.transitions.gather(last_tags)[:, self.EOS_TAG_ID]
+
+        return scores
+
+    def _compute_log_partition(self, emissions, mask):
+        """Compute the partition function in log-space using the forward algorithm.
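+
+        In log-space, the forward recurrence implemented below is:
+            alphas[t, j] = logsumexp_i(alphas[t-1, i] + transitions[i, j])
+                           + emissions[t, j]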
+
+        Args:
+            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
+            mask (paddle.FloatTensor): (batch_size, seq_len)
+
+        Returns:
+            paddle.Tensor: the partition scores for each batch.
+                Shape of (batch_size,)
+        """
+        batch_size, seq_length, nb_labels = emissions.shape
+
+        # in the first iteration, BOS has all the scores
+        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(0) + emissions[:, 0]
+
+        for i in range(1, seq_length):
+            # (bs, nb_labels) -> (bs, 1, nb_labels)
+            e_scores = emissions[:, i].unsqueeze(1)
+
+            # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
+            t_scores = self.transitions.unsqueeze(0)
+
+            # (bs, nb_labels) -> (bs, nb_labels, 1)
+            a_scores = alphas.unsqueeze(2)
+
+            scores = e_scores + t_scores + a_scores
+            new_alphas = paddle.logsumexp(scores, axis=1)
+
+            # update alphas where the mask is valid, otherwise keep the current values
+            is_valid = mask[:, i].unsqueeze(-1)
+            alphas = is_valid * new_alphas + (1 - is_valid) * alphas
+
+        # add the scores for the final transition
+        last_transition = self.transitions[:, self.EOS_TAG_ID]
+        end_scores = alphas + last_transition.unsqueeze(0)
+
+        # return a *log* of sums of exps
+        return paddle.logsumexp(end_scores, axis=1)
+
+    def _viterbi_decode(self, emissions, mask):
+        """Run the Viterbi algorithm to find the most probable sequence of
+        labels given a sequence of emissions.
+
+        Args:
+            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
+            mask (paddle.FloatTensor): (batch_size, seq_len)
+
+        Returns:
+            paddle.Tensor: the viterbi score for each batch.
+                Shape of (batch_size,)
+            list of lists of ints: the best viterbi sequence of labels for each batch
+        """
+        batch_size, seq_length, nb_labels = emissions.shape
+
+        # in the first iteration, BOS has all the scores; from here on we keep the max
+        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(0) + emissions[:, 0]
+
+        backpointers = []
+
+        for i in range(1, seq_length):
+            # (bs, nb_labels) -> (bs, 1, nb_labels)
+            e_scores = emissions[:, i].unsqueeze(1)
+
+            # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
+            t_scores = self.transitions.unsqueeze(0)
+
+            # (bs, nb_labels) -> (bs, nb_labels, 1)
+            a_scores = alphas.unsqueeze(2)
+
+            # combine current scores with previous alphas
+            scores = e_scores + t_scores + a_scores
+
+            # so far this is exactly like the forward algorithm,
+            # but now, instead of calculating the logsumexp,
+            # we will find the highest score and the tag associated with it
+            # TODO(Hui Zhang): paddle.max does not support returning both score and index.
+            # max_scores, max_score_tags = paddle.max(scores, axis=1)
+            max_scores = paddle.max(scores, axis=1)
+            max_score_tags = paddle.argmax(scores, axis=1)
+
+            # update alphas where the mask is valid, otherwise keep the current values
+            is_valid = mask[:, i].unsqueeze(-1)
+            alphas = is_valid * max_scores + (1 - is_valid) * alphas
+
+            # add the max_score_tags to our list of backpointers
+            # max_scores has shape (batch_size, nb_labels), so we transpose it to
+            # be compatible with our previous loopy version of viterbi
+            backpointers.append(max_score_tags.t())
+
+        # add the scores for the final transition
+        last_transition = self.transitions[:, self.EOS_TAG_ID]
+        end_scores = alphas + last_transition.unsqueeze(0)
+
+        # get the final most probable score and the final most probable tag
+        # TODO(Hui Zhang): paddle.max does not support returning both score and index.
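+        # workaround: paddle.max gives the max values and paddle.argmax the
+        # corresponding indices, together replacing the single torch-style max call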
+        # max_final_scores, max_final_tags = paddle.max(end_scores, axis=1)
+        max_final_scores = paddle.max(end_scores, axis=1)
+        max_final_tags = paddle.argmax(end_scores, axis=1)
+
+        # find the best sequence of labels for each sample in the batch
+        best_sequences = []
+        emission_lengths = mask.int().sum(axis=1)
+        for i in range(batch_size):
+
+            # recover the original sentence length for the i-th sample in the batch
+            sample_length = emission_lengths[i].item()
+
+            # recover the max tag for the last timestep
+            sample_final_tag = max_final_tags[i].item()
+
+            # keep the backpointers up to the second-to-last timestep,
+            # since the last one corresponds to sample_final_tag
+            sample_backpointers = backpointers[:sample_length - 1]
+
+            # follow the backpointers to build the sequence of labels
+            sample_path = self._find_best_path(i, sample_final_tag,
+                                               sample_backpointers)
+
+            # add this path to the list of best sequences
+            best_sequences.append(sample_path)
+
+        return max_final_scores, best_sequences
+
+    def _find_best_path(self, sample_id, best_tag, backpointers):
+        """Auxiliary function to find the best path sequence for a specific sample.
+
+        Args:
+            sample_id (int): sample index in the range [0, batch_size)
+            best_tag (int): tag which maximizes the final score
+            backpointers (list of lists of tensors): list of pointers with
+                shape (seq_len_i-1, nb_labels, batch_size) where seq_len_i
+                represents the length of the ith sample in the batch
+
+        Returns:
+            list of ints: a list of tag indexes representing the best path
+        """
+        # add the final best_tag to our best path
+        best_path = [best_tag]
+
+        # traverse the backpointers backwards
+        for backpointers_t in reversed(backpointers):
+
+            # recover the best_tag at this timestep
+            best_tag = backpointers_t[best_tag][sample_id].item()
+
+            # append to the beginning of the list so we don't need to reverse it later
+            best_path.insert(0, best_tag)
+
+        return best_path
diff --git a/examples/chinese_g2p/local/ignore_sandhi.py b/examples/chinese_g2p/local/ignore_sandhi.py
index cda1bd14..b7f37a27 100644
--- a/examples/chinese_g2p/local/ignore_sandhi.py
+++ b/examples/chinese_g2p/local/ignore_sandhi.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
-from typing import List, Union
 from pathlib import Path
+from typing import List
+from typing import Union
 
 
 def erized(syllable: str) -> bool:
@@ -67,7 +68,9 @@ def ignore_sandhi(reference: List[str], generated: List[str]) -> List[str]:
     return result
 
 
-def convert_transcriptions(reference: Union[str, Path], generated: Union[str, Path], output: Union[str, Path]):
+def convert_transcriptions(reference: Union[str, Path],
+                           generated: Union[str, Path],
+                           output: Union[str, Path]):
     with open(reference, 'rt') as f_ref:
         with open(generated, 'rt') as f_gen:
             with open(output, 'wt') as f_out:
diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md
index 09f700da..393dd457 100644
--- a/examples/librispeech/s0/README.md
+++ b/examples/librispeech/s0/README.md
@@ -3,7 +3,7 @@
 ## Deepspeech2
 
 | Model | release | Config | Test set | Loss | WER |
-| --- | --- | --- | --- | --- | --- | 
+| --- | --- | --- | --- | --- | --- |
 | DeepSpeech2 | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 |
 | DeepSpeech2 | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 |
 | DeepSpeech2 | 1.8.5 | - | test-clean | - | 0.074939 |
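
The backpointer walk in `_find_best_path` is easy to sanity-check on paper. A
minimal sketch with made-up pointer values (the per-sample index is dropped
here; in the module each entry is additionally indexed by sample_id):

    # backpointers[t][j]: best previous tag, given that we are in tag j at step t+1
    backpointers = [
        [1, 0, 0],  # step 1
        [2, 2, 1],  # step 2
    ]
    best_tag = 0  # argmax tag at the final step
    best_path = [best_tag]
    for bp_t in reversed(backpointers):
        best_tag = bp_t[best_tag]  # follow the pointer one step back
        best_path.insert(0, best_tag)
    print(best_path)  # prints [0, 2, 0]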