From e5347c48efdc00e3cf2fb5fe6ed22118bcd29298 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 14 Apr 2021 03:00:15 +0000 Subject: [PATCH] test can run --- .notebook/python_test.ipynb | 44 +++++- deepspeech/__init__.py | 23 ++- deepspeech/exps/u2/model.py | 84 +++++++---- .../frontend/featurizer/speech_featurizer.py | 30 +++- .../frontend/featurizer/text_featurizer.py | 113 +++++++++++++-- deepspeech/io/dataset.py | 18 ++- deepspeech/models/u2.py | 135 ++++++++++++++++-- deepspeech/modules/ctc.py | 23 ++- deepspeech/modules/decoder.py | 2 +- deepspeech/modules/decoder_layer.py | 13 +- deepspeech/modules/mask.py | 2 + deepspeech/training/cli.py | 3 + examples/tiny/s1/conf/conformer.yaml | 16 ++- examples/tiny/s1/local/test.sh | 1 + tests/mask_test.py | 19 ++- 15 files changed, 440 insertions(+), 86 deletions(-) diff --git a/.notebook/python_test.ipynb b/.notebook/python_test.ipynb index 9874f5d00..50d5a8331 100644 --- a/.notebook/python_test.ipynb +++ b/.notebook/python_test.ipynb @@ -617,10 +617,52 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 58, "id": "pregnant-modem", "metadata": {}, "outputs": [], + "source": [ + "def get_default_args(fn):\n", + " if fn is None:\n", + " return {}\n", + "\n", + " signature = inspect.signature(fn)\n", + " return {\n", + " k: v.default\n", + " for k, v in signature.parameters.items()\n", + " if v.default is not inspect.Parameter.empty\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "light-drill", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'inspect' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_default_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mget_default_args\u001b[0;34m(fn)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0msignature\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minspect\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msignature\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m return {\n\u001b[1;32m 7\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefault\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'inspect' is not defined" + ] + } + ], + "source": [ + "get_default_args(io.open)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "protective-belgium", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index c6c2e607a..176ae4f96 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -179,6 +179,11 @@ if not hasattr(paddle.Tensor, 'eq'): ) paddle.Tensor.eq = eq +if not hasattr(paddle, 'eq'): + logger.warn( + "override eq of paddle if exists or register, remove this when fixed!") + paddle.eq = eq + def contiguous(xs: paddle.Tensor) -> paddle.Tensor: return 
xs @@ -256,13 +261,14 @@ if not hasattr(paddle.Tensor, 'masked_fill'): def masked_fill_(xs: paddle.Tensor, mask: paddle.Tensor, - value: Union[float, int]): + value: Union[float, int]) -> paddle.Tensor: assert is_broadcastable(xs.shape, mask.shape) is True bshape = paddle.broadcast_shape(xs.shape, mask.shape) mask = mask.broadcast_to(bshape) trues = paddle.ones_like(xs) * value ret = paddle.where(mask, trues, xs) paddle.assign(ret.detach(), output=xs) + return xs if not hasattr(paddle.Tensor, 'masked_fill_'): @@ -271,9 +277,10 @@ if not hasattr(paddle.Tensor, 'masked_fill_'): paddle.Tensor.masked_fill_ = masked_fill_ -def fill_(xs: paddle.Tensor, value: Union[float, int]): +def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor: val = paddle.full_like(xs, value) paddle.assign(val.detach(), output=xs) + return xs if not hasattr(paddle.Tensor, 'fill_'): @@ -317,7 +324,7 @@ if not hasattr(paddle.Tensor, 'type_as'): def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: assert len(args) == 1 - if isinstace(args[0], str): # dtype + if isinstance(args[0], str): # dtype return x.astype(args[0]) elif isinstance(args[0], paddle.Tensor): #Tensor return x.astype(args[0].dtype) @@ -338,6 +345,16 @@ if not hasattr(paddle.Tensor, 'float'): logger.warn("register user float to paddle.Tensor, remove this when fixed!") setattr(paddle.Tensor, 'float', func_float) + +def tolist(x: paddle.Tensor) -> List[Any]: + return x.numpy().tolist() + + +if not hasattr(paddle.Tensor, 'tolist'): + logger.warn( + "register user tolist to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'tolist', tolist) + ########### hcak paddle.nn.functional ############# diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 6221b9318..b5f90f046 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -13,6 +13,7 @@ # limitations under the License. """Contains U2 model.""" +import sys import time import logging import numpy as np @@ -256,11 +257,19 @@ class U2Tester(U2Trainer): cutoff_prob=1.0, # Cutoff probability for pruning. cutoff_top_n=40, # Cutoff number for pruning. lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. - decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + decoding_method='attention', # Decoding method. Options: 'attention', 'ctc_greedy_search', + # 'ctc_prefix_beam_search', 'attention_rescoring' error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' num_proc_bsearch=8, # # of CPUs for beam search. - beam_size=500, # Beam search width. - batch_size=128, # decoding batch size + beam_size=10, # Beam search width. + batch_size=16, # decoding batch size + ctc_weight=0.0, # ctc weight for attention rescoring decode mode. + decoding_chunk_size=-1, # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1. + simulate_streaming=False, # simulate streaming inference. Defaults to False. 
)) if config is not None: @@ -279,19 +288,19 @@ class U2Tester(U2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, audio, texts, audio_len, texts_len): + def compute_metrics(self, audio, audio_len, texts, texts_len, fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer - vocab_list = self.test_loader.dataset.vocab_list + text_feature = self.test_loader.dataset.text_feature target_transcripts = self.ordid2token(texts, texts_len) result_transcripts = self.model.decode( audio, audio_len, - vocab_list, + text_feature=text_feature, decoding_method=cfg.decoding_method, lang_model_path=cfg.lang_model_path, beam_alpha=cfg.alpha, @@ -299,13 +308,19 @@ class U2Tester(U2Trainer): beam_size=cfg.beam_size, cutoff_prob=cfg.cutoff_prob, cutoff_top_n=cfg.cutoff_top_n, - num_processes=cfg.num_proc_bsearch) + num_processes=cfg.num_proc_bsearch, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) for target, result in zip(target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref num_ins += 1 + if fout: + fout.write(result + "\n") self.logger.info( "\nTarget Transcription: %s\nOutput Transcription: %s" % (target, result)) @@ -322,6 +337,7 @@ class U2Tester(U2Trainer): @mp_tools.rank_zero_only @paddle.no_grad() def test(self): + assert self.args.result_file self.model.eval() self.logger.info( f"Test Total Examples: {len(self.test_loader.dataset)}") @@ -329,14 +345,16 @@ class U2Tester(U2Trainer): error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 - for i, batch in enumerate(self.test_loader): - metrics = self.compute_metrics(*batch) - errors_sum += metrics['errors_sum'] - len_refs += metrics['len_refs'] - num_ins += metrics['num_ins'] - error_rate_type = metrics['error_rate_type'] - self.logger.info("Error rate [%s] (%d/?) = %f" % - (error_rate_type, num_ins, errors_sum / len_refs)) + with open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + metrics = self.compute_metrics(*batch, fout=fout) + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + self.logger.info( + "Error rate [%s] (%d/?) = %f" % + (error_rate_type, num_ins, errors_sum / len_refs)) # logging msg = "Test: " @@ -351,24 +369,34 @@ class U2Tester(U2Trainer): try: self.test() except KeyboardInterrupt: - exit(-1) + sys.exit(-1) - def export(self): + def load_inferspec(self): + """infer model and input spec. + + Returns: + nn.Layer: inference model + List[paddle.static.InputSpec]: input spec. 
+ """ from deepspeech.models.u2 import U2InferModel infer_model = U2InferModel.from_pretrained(self.test_loader.dataset, self.config.model.clone(), self.args.checkpoint_path) - infer_model.eval() feat_dim = self.test_loader.dataset.feature_size - static_model = paddle.jit.to_static( - infer_model, - input_spec=[ - paddle.static.InputSpec( - shape=[None, feat_dim, None], - dtype='float32'), # audio, [B,D,T] - paddle.static.InputSpec(shape=[None], - dtype='int64'), # audio_length, [B] - ]) + input_spec = [ + paddle.static.InputSpec( + shape=[None, feat_dim, None], + dtype='float32'), # audio, [B,D,T] + paddle.static.InputSpec(shape=[None], + dtype='int64'), # audio_length, [B] + ] + return infer_model, input_spec + + def export(self): + infer_model, input_spec = self.load_inferspec() + assert isinstance(input_spec, list), type(input_spec) + infer_model.eval() + static_model = paddle.jit.to_static(infer_model, input_spec=input_spec) logger.info(f"Export code: {static_model.forward.code}") paddle.jit.save(static_model, self.args.export_path) @@ -376,7 +404,7 @@ class U2Tester(U2Trainer): try: self.export() except KeyboardInterrupt: - exit(-1) + sys.exit(-1) def setup(self): """Setup the experiment. diff --git a/deepspeech/frontend/featurizer/speech_featurizer.py b/deepspeech/frontend/featurizer/speech_featurizer.py index 6530fc937..920bec538 100644 --- a/deepspeech/frontend/featurizer/speech_featurizer.py +++ b/deepspeech/frontend/featurizer/speech_featurizer.py @@ -107,8 +107,8 @@ class SpeechFeaturizer(object): def vocab_size(self): """Return the vocabulary size. - :return: Vocabulary size. - :rtype: int + Returns: + int: Vocabulary size. """ return self._text_featurizer.vocab_size @@ -116,16 +116,34 @@ class SpeechFeaturizer(object): def vocab_list(self): """Return the vocabulary in list. - :return: Vocabulary in list. - :rtype: list + Returns: + List[str]: """ return self._text_featurizer.vocab_list + @property + def vocab_dict(self): + """Return the vocabulary in dict. + + Returns: + Dict[str, int]: + """ + return self._text_featurizer.vocab_dict + @property def feature_size(self): """Return the audio feature size. - :return: audio feature size. - :rtype: int + Returns: + int: audio feature size. """ return self._audio_featurizer.feature_size + + @property + def text_feature(self): + """Return the text feature object. + + Returns: + TextFeaturizer: object. + """ + return self._text_featurizer diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index e1d34e5ac..d70f88f44 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -16,6 +16,7 @@ import sentencepiece as spm from deepspeech.frontend.utility import UNK +from deepspeech.frontend.utility import EOS class TextFeaturizer(object): @@ -32,10 +33,12 @@ class TextFeaturizer(object): spm_model_prefix (str, optional): spm model prefix. Defaults to None. 
""" assert unit_type in ('char', 'spm', 'word') - self.unk = UNK self.unit_type = unit_type - self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file( + self._vocab_dict, self._id2token, self._vocab_list = self._load_vocabulary_from_file( vocab_filepath) + self.unk = UNK + self.unk_id = self._vocab_list.index(self.unk) + self.eos_id = self._vocab_list.index(EOS) if unit_type == 'spm': spm_model = spm_model_prefix + '.model' @@ -51,14 +54,23 @@ class TextFeaturizer(object): tokens = self.spm_tokenize(text) return tokens + def detokenize(self, tokens): + if self.unit_type == 'char': + text = self.char_detokenize(tokens) + elif self.unit_type == 'word': + text = self.word_detokenize(tokens) + else: # spm + text = self.spm_detokenize(tokens) + return text + def featurize(self, text): - """Convert text string to a list of token indices in char-level.Note - that the token indexing order follows the given vocabulary file. + """Convert text string to a list of token indices. - :param text: Text to process. - :type text: str - :return: List of char-level token indices. - :rtype: List[int] + Args: + text (str): Text to process. + + Returns: + List[int]: List of token indices. """ tokens = self.tokenize(text) ids = [] @@ -67,6 +79,24 @@ class TextFeaturizer(object): ids.append(self._vocab_dict[token]) return ids + def defeaturize(self, idxs): + """Convert a list of token indices to text string, + ignore index after eos_id. + + Args: + idxs (List[int]): List of token indices. + + Returns: + str: Text to process. + """ + tokens = [] + for idx in idxs: + if idx == self.eos_id: + break + tokens.append(self._id2token[idx]) + text = self.detokenize(tokens) + return text + @property def vocab_size(self): """Return the vocabulary size. @@ -80,19 +110,50 @@ class TextFeaturizer(object): def vocab_list(self): """Return the vocabulary in list. - :return: Vocabulary in list. - :rtype: list + Returns: + List[str]: tokens. """ return self._vocab_list + @property + def vocab_dict(self): + """Return the vocabulary in dict. + + Returns: + Dict[str, int]: token str -> int + """ + return self._vocab_dict + def char_tokenize(self, text): - """Character tokenizer.""" + """Character tokenizer. + + Args: + text (str): text string. + + Returns: + List[str]: tokens. + """ return list(text.strip()) + def char_detokenize(self, tokens): + """Character detokenizer. + + Args: + tokens (List[str]): tokens. + + Returns: + str: text string. + """ + return "".join(tokens) + def word_tokenize(self, text): - """Word tokenizer, spearte by .""" + """Word tokenizer, separate by .""" return text.strip().split() + def word_detokenize(self, tokens): + """Word detokenizer, separate by .""" + return " ".join(tokens) + def spm_tokenize(self, text): """spm tokenize. @@ -125,12 +186,34 @@ class TextFeaturizer(object): enc_line = encode_line(text) return enc_line + def spm_detokenize(self, tokens, input_format='piece'): + """spm detokenize. + + Args: + ids (List[str]): tokens. 
+ + Returns: + str: text + """ + if input_format == "piece": + + def decode(l): + return "".join(self.sp.DecodePieces(l)) + elif input_format == "id": + + def decode(l): + return "".join(self.sp.DecodeIds(l)) + + return decode(tokens) + def _load_vocabulary_from_file(self, vocab_filepath): """Load vocabulary from file.""" vocab_lines = [] with open(vocab_filepath, 'r', encoding='utf-8') as file: vocab_lines.extend(file.readlines()) vocab_list = [line[:-1] for line in vocab_lines] - vocab_dict = dict( - [(token, id) for (id, token) in enumerate(vocab_list)]) - return vocab_dict, vocab_list + id2token = dict( + [(idx, token) for (idx, token) in enumerate(vocab_list)]) + token2id = dict( + [(token, idx) for (idx, token) in enumerate(vocab_list)]) + return token2id, id2token, vocab_list diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 4875929eb..4550b058a 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -233,22 +233,20 @@ class ManifestDataset(Dataset): @property def vocab_size(self): - """Return the vocabulary size. - - :return: Vocabulary size. - :rtype: int - """ return self._speech_featurizer.vocab_size @property def vocab_list(self): - """Return the vocabulary in list. - - :return: Vocabulary in list. - :rtype: list - """ return self._speech_featurizer.vocab_list + @property + def vocab_dict(self): + return self._speech_featurizer.vocab_dict + + @property + def text_feature(self): + return self._speech_featurizer.text_feature + @property def feature_size(self): return self._speech_featurizer.feature_size diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 95c13d402..54a993998 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -16,10 +16,11 @@ Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recogni (https://arxiv.org/pdf/2012.05481.pdf) """ +import sys from collections import defaultdict import logging from yacs.config import CfgNode -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict import paddle from paddle import jit @@ -132,6 +133,7 @@ class U2BaseModel(nn.Module): smoothing=lsm_weight, normalize_length=length_normalized_loss, ) + @jit.export def forward( self, speech: paddle.Tensor, @@ -158,7 +160,7 @@ class U2BaseModel(nn.Module): encoder_out, encoder_mask = self.encoder(speech, speech_lengths) #TODO(Hui Zhang): sum not support bool type #encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B] - encoder_out_lens = encoder_mask.squeeze(1).astype(paddle.int64).sum( + encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum( 1) #[B, 1, T] -> [B] # 2a. Attention-decoder branch @@ -301,14 +303,15 @@ class U2BaseModel(nn.Module): # log scale score scores = paddle.to_tensor( [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float) - scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to( + scores = scores.to(device).repeat(batch_size).unsqueeze(1).to( device) # (B*N, 1) end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1) cache: Optional[List[paddle.Tensor]] = None # 2. 
Decoder forward step by step for i in range(1, maxlen + 1): # Stop if all batch and all beam produce eos - if end_flag.sum() == running_size: + # TODO(Hui Zhang): if end_flag.sum() == running_size: + if end_flag.cast(paddle.int64).sum() == running_size: break # 2.1 Forward decoder step @@ -333,7 +336,7 @@ class U2BaseModel(nn.Module): # regard top_k_index as (B*N*N),regard offset_k_index as (B*N), # then find offset_k_index in top_k_index base_k_index = paddle.arange(batch_size).view(-1, 1).repeat( - [1, beam_size]) # (B, N) + 1, beam_size) # (B, N) base_k_index = base_k_index * beam_size * beam_size best_k_index = base_k_index.view(-1) + offset_k_index.view( -1) # (B*N) @@ -678,6 +681,108 @@ class U2BaseModel(nn.Module): decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1) return decoder_out + @paddle.no_grad() + def decode(self, + feats: paddle.Tensor, + feats_lengths: paddle.Tensor, + text_feature: Dict[str, int], + decoding_method: str, + lang_model_path: str, + beam_alpha: float, + beam_beta: float, + beam_size: int, + cutoff_prob: float, + cutoff_top_n: int, + num_processes: int, + ctc_weight: float=0.0, + decoding_chunk_size: int=-1, + num_decoding_left_chunks: int=-1, + simulate_streaming: bool=False): + """u2 decoding. + + Args: + feats (Tenosr): audio features, (B, T, D) + feats_lengths (Tenosr): (B) + text_feature (TextFeaturizer): text feature object. + decoding_method (str): decoding mode, e.g. + 'attention', 'ctc_greedy_search', + 'ctc_prefix_beam_search', 'attention_rescoring' + lang_model_path (str): lm path. + beam_alpha (float): lm weight. + beam_beta (float): length penalty. + beam_size (int): beam size for search + cutoff_prob (float): for prune. + cutoff_top_n (int): for prune. + num_processes (int): + ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0. + decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + 0: used for training, it's prohibited here. + num_decoding_left_chunks (int, optional): + number of left chunks for decoding. Defaults to -1. + simulate_streaming (bool, optional): simulate streaming inference. Defaults to False. + + Raises: + ValueError: when not support decoding_method. + + Returns: + List[List[int]]: transcripts. 
+ """ + batch_size = feats.size(0) + if decoding_method in ['ctc_prefix_beam_search', + 'attention_rescoring'] and batch_size > 1: + logger.fatal( + f'decoding mode {decoding_method} must be running with batch_size == 1' + ) + sys.exit(1) + + if decoding_method == 'attention': + hyps = self.recognize( + feats, + feats_lengths, + beam_size=beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) + hyps = [hyp.tolist() for hyp in hyps] + elif decoding_method == 'ctc_greedy_search': + hyps = self.ctc_greedy_search( + feats, + feats_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) + # ctc_prefix_beam_search and attention_rescoring only return one + # result in List[int], change it to List[List[int]] for compatible + # with other batch decoding mode + elif decoding_method == 'ctc_prefix_beam_search': + assert feats.size(0) == 1 + hyp = self.ctc_prefix_beam_search( + feats, + feats_lengths, + beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) + hyps = [hyp] + elif decoding_method == 'attention_rescoring': + assert feats.size(0) == 1 + hyp = self.attention_rescoring( + feats, + feats_lengths, + beam_size, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + ctc_weight=ctc_weight, + simulate_streaming=simulate_streaming) + hyps = [hyp] + else: + raise ValueError(f"Not support decoding method: {decoding_method}") + + res = [text_feature.defeaturize(hyp) for hyp in hyps] + return res + class U2Model(U2BaseModel): def __init__(self, configs: dict): @@ -779,14 +884,24 @@ class U2InferModel(U2Model): def __init__(self, configs: dict): super().__init__(configs) - def forward(self, audio, audio_len): + def forward(self, + feats, + feats_lengths, + decoding_chunk_size=-1, + num_decoding_left_chunks=-1, + simulate_streaming=False): """export model function Args: - audio (Tensor): [B, T, D] - audio_len (Tensor): [B] + feats (Tensor): [B, T, D] + feats_lengths (Tensor): [B] Returns: - probs: probs after softmax + List[List[int]]: best path result """ - raise NotImplementedError("U2Model infer") + return self.ctc_greedy_search( + feats, + feats_lengths, + decoding_chunk_size=decoding_chunk_size, + num_decoding_left_chunks=num_decoding_left_chunks, + simulate_streaming=simulate_streaming) diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index be283165c..64508a74d 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -224,9 +224,28 @@ class CTCDecoder(nn.Layer): def decode_probs(self, probs, logits_lens, vocab_list, decoding_method, lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n, num_processes): - """ probs: activation after softmax - logits_len: audio output lens + """ctc decoding with probs. + + Args: + probs (Tenosr): activation after softmax + logits_lens (Tenosr): audio output lens + vocab_list ([type]): [description] + decoding_method ([type]): [description] + lang_model_path ([type]): [description] + beam_alpha ([type]): [description] + beam_beta ([type]): [description] + beam_size ([type]): [description] + cutoff_prob ([type]): [description] + cutoff_top_n ([type]): [description] + num_processes ([type]): [description] + + Raises: + ValueError: when decoding_method not support. 
+ + Returns: + List[str]: transcripts. """ + probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)] if decoding_method == "ctc_greedy": result_transcripts = self._decode_batch_greedy( diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py index 3e52bc7ab..11f18f7bf 100644 --- a/deepspeech/modules/decoder.py +++ b/deepspeech/modules/decoder.py @@ -176,5 +176,5 @@ class TransformerDecoder(nn.Module): else: y = x[:, -1] if self.use_output_layer: - y = paddle.log_softmax(self.output_layer(y), dim=-1) + y = paddle.log_softmax(self.output_layer(y), axis=-1) return y, new_cache diff --git a/deepspeech/modules/decoder_layer.py b/deepspeech/modules/decoder_layer.py index 64e16b75a..a781ea5d5 100644 --- a/deepspeech/modules/decoder_layer.py +++ b/deepspeech/modules/decoder_layer.py @@ -101,12 +101,17 @@ class DecoderLayer(nn.Module): tgt_q_mask = tgt_mask else: # compute only the last frame query keeping dim: max_time_out -> 1 - assert cache.shape == ( - tgt.shape[0], tgt.shape[1] - 1, self.size, - ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + assert cache.shape == [ + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ], f"{cache.shape} == {[tgt.shape[0], tgt.shape[1] - 1, self.size]}" tgt_q = tgt[:, -1:, :] residual = residual[:, -1:, :] - tgt_q_mask = tgt_mask[:, -1:, :] + # TODO(Hui Zhang): slice not support bool type + # tgt_q_mask = tgt_mask[:, -1:, :] + tgt_q_mask = tgt_mask.cast(paddle.int64)[:, -1:, :].cast( + paddle.bool) if self.concat_after: tgt_concat = paddle.cat( diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py index d9430a269..beb67db7d 100644 --- a/deepspeech/modules/mask.py +++ b/deepspeech/modules/mask.py @@ -41,6 +41,7 @@ def sequence_mask(x_len, max_len=None, dtype='float32'): [[1., 1., 0., 0.], [1., 1., 1., 1.]] """ + assert x_len.dim() == 1 max_len = max_len or x_len.max() x_len = paddle.unsqueeze(x_len, -1) row_vector = paddle.arange(max_len) @@ -65,6 +66,7 @@ def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor: [0, 0, 0, 1, 1], [0, 0, 1, 1, 1]] """ + assert lengths.dim() == 1 batch_size = int(lengths.shape[0]) max_len = int(lengths.max()) seq_range = paddle.arange(0, max_len, dtype=paddle.int64) diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index e36c9264d..d06672825 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -57,6 +57,9 @@ def default_argument_parser(): # save jit model to parser.add_argument("--export_path", type=str, help="path of the jit model to save") + # save asr result to + parser.add_argument("--result_file", type=str, help="path of save the asr result") + # running parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"], help="device type to use, cpu and gpu are supported.") diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index af26f0291..7d4303660 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -126,16 +126,24 @@ training: lr_decay: 1.0 log_interval: 100 + decoding: - batch_size: 128 + batch_size: 16 error_rate_type: wer - decoding_method: ctc_beam_search + decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm alpha: 2.5 beta: 0.3 - beam_size: 500 + beam_size: 10 cutoff_prob: 1.0 - cutoff_top_n: 40 + cutoff_top_n: 0 num_proc_bsearch: 8 + ctc_weight: 0.0 # ctc weight for attention rescoring decode mode. 
+ decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. + simulate_streaming: False # simulate streaming inference. Defaults to False. diff --git a/examples/tiny/s1/local/test.sh b/examples/tiny/s1/local/test.sh index 6dbb1c388..475e941e7 100644 --- a/examples/tiny/s1/local/test.sh +++ b/examples/tiny/s1/local/test.sh @@ -11,6 +11,7 @@ python3 -u ${BIN_DIR}/test.py \ --device 'gpu' \ --nproc 1 \ --config conf/conformer.yaml \ +--result_file data/asr.result \ --output ckpt if [ $? -ne 0 ]; then diff --git a/tests/mask_test.py b/tests/mask_test.py index f5e9cd7cb..afdd53ba8 100644 --- a/tests/mask_test.py +++ b/tests/mask_test.py @@ -17,14 +17,23 @@ import numpy as np import unittest from deepspeech.modules.mask import sequence_mask from deepspeech.modules.mask import make_non_pad_mask +from deepspeech.modules.mask import make_pad_mask class TestU2Model(unittest.TestCase): def setUp(self): paddle.set_device('cpu') self.lengths = paddle.to_tensor([5, 3, 2]) - self.masks = np.array( - [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 0, 0]], ) + self.masks = np.array([ + [1, 1, 1, 1, 1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0], + ]) + self.pad_masks = np.array([ + [0, 0, 0, 0, 0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1], + ]) def test_sequence_mask(self): res = sequence_mask(self.lengths) @@ -32,7 +41,13 @@ class TestU2Model(unittest.TestCase): def test_make_non_pad_mask(self): res = make_non_pad_mask(self.lengths) + res1 = sequence_mask(self.lengths) self.assertSequenceEqual(res.numpy().tolist(), self.masks.tolist()) + self.assertSequenceEqual(res.numpy().tolist(), res1.numpy().tolist()) + + def test_make_pad_mask(self): + res = make_pad_mask(self.lengths) + self.assertSequenceEqual(res.numpy().tolist(), self.pad_masks.tolist()) if __name__ == '__main__':
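
The helper committed to .notebook/python_test.ipynb fails with "NameError: name 'inspect' is not defined" because the cell never imports the module it uses. The same helper, self-contained with the missing import and the io.open probe from the failing cell:

import inspect
import io


def get_default_args(fn):
    """Map each keyword parameter of `fn` to its default value."""
    if fn is None:
        return {}
    signature = inspect.signature(fn)
    return {
        k: v.default
        for k, v in signature.parameters.items()
        if v.default is not inspect.Parameter.empty
    }


# prints something like {'mode': 'r', 'buffering': -1, 'encoding': None, ...}
print(get_default_args(io.open))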
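
deepspeech/__init__.py now makes masked_fill_ and fill_ return the tensor they modify and registers a tolist helper on paddle.Tensor. A minimal sketch of the patched behaviour; it assumes the monkey patches take effect simply by importing the package and that the installed paddle build does not already provide these methods natively:

import paddle

import deepspeech  # noqa: F401 -- importing the package registers the Tensor patches

x = paddle.zeros([2, 3])
mask = paddle.to_tensor([[True, False, True],
                         [False, False, True]])

y = x.masked_fill_(mask, -1.0)  # fills in place and, after this patch, also returns x
print(y.tolist())               # tolist() is the newly registered helper
print(x.fill_(0.0).tolist())    # fill_ can now be chained the same way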
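
TextFeaturizer gains detokenize/defeaturize, so token ids can be mapped back to text, stopping at the first eos id. A rough round-trip sketch for the char unit type; the vocab path is a placeholder, and the vocab file is assumed to contain the UNK/EOS tokens the constructor now looks up plus every character of the sample string:

from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer

featurizer = TextFeaturizer(unit_type='char', vocab_filepath='data/vocab.txt')

ids = featurizer.featurize("hello world")  # str -> List[int]
print(ids)
print(featurizer.defeaturize(ids))         # List[int] -> str, truncated at eos_id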
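
U2BaseModel.decode is the single entry point the tester now drives: it dispatches on decoding_method and converts ids back to text through the dataset's text_feature. A condensed call sketch using the defaults from the updated conf/conformer.yaml; model, test_dataset, feats and feats_len are placeholders for an already built U2 model and a padded feature batch:

# feats: (B, T, D) float32 features, feats_len: (B,) int64 lengths -- placeholders
hyps = model.decode(
    feats,
    feats_len,
    text_feature=test_dataset.text_feature,  # TextFeaturizer used for defeaturize
    decoding_method='attention',  # or ctc_greedy_search / ctc_prefix_beam_search / attention_rescoring
    lang_model_path='data/lm/common_crawl_00.prune01111.trie.klm',
    beam_alpha=2.5,
    beam_beta=0.3,
    beam_size=10,
    cutoff_prob=1.0,
    cutoff_top_n=0,
    num_processes=8,
    ctc_weight=0.0,
    decoding_chunk_size=-1,       # <0 means full-chunk, non-streaming decoding
    num_decoding_left_chunks=-1,
    simulate_streaming=False)
print(hyps)  # List[str]; ctc_prefix_beam_search and attention_rescoring require batch size 1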
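
The extended tests/mask_test.py pins down how the three mask helpers relate for lengths [5, 3, 2]: make_non_pad_mask produces the same pattern as sequence_mask over the valid frames, and make_pad_mask is its complement over the padding. The same relationship as a quick interactive probe:

import paddle

from deepspeech.modules.mask import make_non_pad_mask, make_pad_mask, sequence_mask

lengths = paddle.to_tensor([5, 3, 2])
print(sequence_mask(lengths).numpy())      # marks the valid frames per utterance
print(make_non_pad_mask(lengths).numpy())  # same pattern as sequence_mask
print(make_pad_mask(lengths).numpy())      # complementary mask over the padding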