diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index 6ed1177a..ed209f3d 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -233,7 +233,8 @@ def is_broadcastable(shp1, shp2): def masked_fill(xs: paddle.Tensor, mask: paddle.Tensor, value: Union[float, int]): - assert is_broadcastable(xs.shape, mask.shape) is True + assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape, + mask.shape) bshape = paddle.broadcast_shape(xs.shape, mask.shape) mask = mask.broadcast_to(bshape) trues = paddle.ones_like(xs) * value @@ -315,7 +316,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: assert len(args) == 1 if isinstance(args[0], str): # dtype return x.astype(args[0]) - elif isinstance(args[0], paddle.Tensor): #Tensor + elif isinstance(args[0], paddle.Tensor): # Tensor return x.astype(args[0].dtype) else: # Device return x @@ -364,6 +365,7 @@ from typing import Tuple from typing import Iterator from collections import OrderedDict, abc as container_abcs + class LayerDict(paddle.nn.Layer): r"""Holds submodules in a dictionary. @@ -408,7 +410,7 @@ class LayerDict(paddle.nn.Layer): return x """ - def __init__(self, modules: Optional[Mapping[str, Layer]] = None) -> None: + def __init__(self, modules: Optional[Mapping[str, Layer]]=None) -> None: super(LayerDict, self).__init__() if modules is not None: self.update(modules) @@ -475,10 +477,11 @@ class LayerDict(paddle.nn.Layer): """ if not isinstance(modules, container_abcs.Iterable): raise TypeError("LayerDict.update should be called with an " - "iterable of key/value pairs, but got " + - type(modules).__name__) + "iterable of key/value pairs, but got " + type( + modules).__name__) - if isinstance(modules, (OrderedDict, LayerDict, container_abcs.Mapping)): + if isinstance(modules, + (OrderedDict, LayerDict, container_abcs.Mapping)): for key, module in modules.items(): self[key] = module else: @@ -490,14 +493,15 @@ class LayerDict(paddle.nn.Layer): type(m).__name__) if not len(m) == 2: raise ValueError("LayerDict update sequence element " - "#" + str(j) + " has length " + str(len(m)) + - "; 2 is required") + "#" + str(j) + " has length " + str( + len(m)) + "; 2 is required") # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] # that's too cumbersome to type correctly with overloads, so we add an ignore here self[m[0]] = m[1] # type: ignore[assignment] # remove forward alltogether to fallback on Module's _forward_unimplemented + if not hasattr(paddle.nn, 'LayerDict'): logger.debug( "register user LayerDict to paddle.nn, remove this when fixed!") diff --git a/deepspeech/decoders/beam_search/__init__.py b/deepspeech/decoders/beam_search/__init__.py new file mode 100644 index 00000000..79a1e9d3 --- /dev/null +++ b/deepspeech/decoders/beam_search/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .batch_beam_search import BatchBeamSearch +from .beam_search import beam_search +from .beam_search import BeamSearch +from .beam_search import Hypothesis diff --git a/deepspeech/decoders/beam_search/batch_beam_search.py b/deepspeech/decoders/beam_search/batch_beam_search.py new file mode 100644 index 00000000..3fc1c435 --- /dev/null +++ b/deepspeech/decoders/beam_search/batch_beam_search.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class BatchBeamSearch(): + pass diff --git a/deepspeech/decoders/beam_search.py b/deepspeech/decoders/beam_search/beam_search.py similarity index 72% rename from deepspeech/decoders/beam_search.py rename to deepspeech/decoders/beam_search/beam_search.py index 5bf0d3e2..8fd8f9b8 100644 --- a/deepspeech/decoders/beam_search.py +++ b/deepspeech/decoders/beam_search/beam_search.py @@ -1,7 +1,18 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Beam search module.""" - from itertools import chain -import logger from typing import Any from typing import Dict from typing import List @@ -11,18 +22,18 @@ from typing import Union import paddle -from .utils import end_detect -from .scorers.scorer_interface import PartialScorerInterface -from .scorers.scorer_interface import ScorerInterface - +from ..scorers.scorer_interface import PartialScorerInterface +from ..scorers.scorer_interface import ScorerInterface +from ..utils import end_detect from deepspeech.utils.log import Log logger = Log(__name__).getlog() + class Hypothesis(NamedTuple): """Hypothesis data type.""" - yseq: paddle.Tensor # (T,) + yseq: paddle.Tensor # (T,) score: Union[float, paddle.Tensor] = 0 scores: Dict[str, Union[float, paddle.Tensor]] = dict() states: Dict[str, Any] = dict() @@ -32,25 +43,24 @@ class Hypothesis(NamedTuple): return self._replace( yseq=self.yseq.tolist(), score=float(self.score), - scores={k: float(v) for k, v in self.scores.items()}, - )._asdict() + scores={k: float(v) + for k, v in self.scores.items()}, )._asdict() class BeamSearch(paddle.nn.Layer): """Beam search implementation.""" def __init__( - self, - scorers: Dict[str, ScorerInterface], - weights: Dict[str, float], - beam_size: int, - vocab_size: int, - sos: int, - eos: int, - token_list: List[str] = None, - pre_beam_ratio: float = 1.5, - pre_beam_score_key: str = None, - ): + self, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + beam_size: int, + vocab_size: int, + sos: int, + eos: int, + token_list: List[str]=None, + pre_beam_ratio: float=1.5, + pre_beam_score_key: str=None, ): """Initialize beam search. Args: @@ -72,12 +82,12 @@ class BeamSearch(paddle.nn.Layer): super().__init__() # set scorers self.weights = weights - self.scorers = dict() # all = full + partial - self.full_scorers = dict() # full tokens - self.part_scorers = dict() # partial tokens + self.scorers = dict() # all = full + partial + self.full_scorers = dict() # full tokens + self.part_scorers = dict() # partial tokens # this module dict is required for recursive cast # `self.to(device, dtype)` in `recog.py` - self.nn_dict = paddle.nn.LayerDict() # nn.Layer + self.nn_dict = paddle.nn.LayerDict() # nn.Layer for k, v in scorers.items(): w = weights.get(k, 0) if w == 0 or v is None: @@ -101,20 +111,16 @@ class BeamSearch(paddle.nn.Layer): self.pre_beam_size = int(pre_beam_ratio * beam_size) self.beam_size = beam_size self.n_vocab = vocab_size - if ( - pre_beam_score_key is not None - and pre_beam_score_key != "full" - and pre_beam_score_key not in self.full_scorers - ): - raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}") + if (pre_beam_score_key is not None and pre_beam_score_key != "full" and + pre_beam_score_key not in self.full_scorers): + raise KeyError( + f"{pre_beam_score_key} is not found in {self.full_scorers}") # selected `key` scorer to do pre beam search self.pre_beam_score_key = pre_beam_score_key # do_pre_beam when need, valid and has part_scorers - self.do_pre_beam = ( - self.pre_beam_score_key is not None - and self.pre_beam_size < self.n_vocab - and len(self.part_scorers) > 0 - ) + self.do_pre_beam = (self.pre_beam_score_key is not None and + self.pre_beam_size < self.n_vocab and + len(self.part_scorers) > 0) def init_hyp(self, x: paddle.Tensor) -> List[Hypothesis]: """Get an initial hypothesis data. 
@@ -136,12 +142,12 @@ class BeamSearch(paddle.nn.Layer): yseq=paddle.to_tensor([self.sos], place=x.place), score=0.0, scores=init_scores, - states=init_states, - ) + states=init_states, ) ] @staticmethod - def append_token(xs: paddle.Tensor, x: int) -> paddle.Tensor: + def append_token(xs: paddle.Tensor, + x: Union[int, paddle.Tensor]) -> paddle.Tensor: """Append new token to prefix tokens. Args: @@ -152,12 +158,11 @@ class BeamSearch(paddle.nn.Layer): paddle.Tensor: (T+1,), New tensor contains: xs + [x] with xs.dtype and xs.device """ - x = paddle.to_tensor([x], dtype=xs.dtype, place=xs.place) - return paddle.cat((xs, x)) + x = paddle.to_tensor([x], dtype=xs.dtype) if isinstance(x, int) else x + return paddle.concat((xs, x)) - def score_full( - self, hyp: Hypothesis, x: paddle.Tensor - ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: + def score_full(self, hyp: Hypothesis, x: paddle.Tensor + ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: """Score new hypothesis by `self.full_scorers`. Args: @@ -179,9 +184,11 @@ class BeamSearch(paddle.nn.Layer): scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x) return scores, states - def score_partial( - self, hyp: Hypothesis, ids: paddle.Tensor, x: paddle.Tensor - ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: + def score_partial(self, + hyp: Hypothesis, + ids: paddle.Tensor, + x: paddle.Tensor + ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: """Score new hypothesis by `self.part_scorers`. Args: @@ -202,12 +209,12 @@ class BeamSearch(paddle.nn.Layer): states = dict() for k, d in self.part_scorers.items(): # scores[k] shape (len(ids),) - scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x) + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], + x) return scores, states - def beam( - self, weighted_scores: paddle.Tensor, ids: paddle.Tensor - ) -> Tuple[paddle.Tensor, paddle.Tensor]: + def beam(self, weighted_scores: paddle.Tensor, + ids: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute topk full token ids and partial token ids. Args: @@ -224,7 +231,8 @@ class BeamSearch(paddle.nn.Layer): """ # no pre beam performed, `ids` equal to `weighted_scores` if weighted_scores.size(0) == ids.size(0): - top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab + top_ids = weighted_scores.topk( + self.beam_size)[1] # index in n_vocab return top_ids, top_ids # mask pruned in pre-beam not to select in topk @@ -232,18 +240,18 @@ class BeamSearch(paddle.nn.Layer): weighted_scores[:] = -float("inf") weighted_scores[ids] = tmp # top_ids no equal to local_ids, since ids shape not same - top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab - local_ids = weighted_scores[ids].topk(self.beam_size)[1] # index in len(ids) + top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab + local_ids = weighted_scores[ids].topk( + self.beam_size)[1] # index in len(ids) return top_ids, local_ids @staticmethod def merge_scores( - prev_scores: Dict[str, float], - next_full_scores: Dict[str, paddle.Tensor], - full_idx: int, - next_part_scores: Dict[str, paddle.Tensor], - part_idx: int, - ) -> Dict[str, paddle.Tensor]: + prev_scores: Dict[str, float], + next_full_scores: Dict[str, paddle.Tensor], + full_idx: int, + next_part_scores: Dict[str, paddle.Tensor], + part_idx: int, ) -> Dict[str, paddle.Tensor]: """Merge scores for new hypothesis. 
Args: @@ -289,9 +297,8 @@ class BeamSearch(paddle.nn.Layer): new_states[k] = d.select_state(part_states[k], part_idx) return new_states - def search( - self, running_hyps: List[Hypothesis], x: paddle.Tensor - ) -> List[Hypothesis]: + def search(self, running_hyps: List[Hypothesis], + x: paddle.Tensor) -> List[Hypothesis]: """Search new tokens for running hypotheses and encoded speech x. Args: @@ -306,17 +313,15 @@ class BeamSearch(paddle.nn.Layer): part_ids = paddle.arange(self.n_vocab) # no pre-beam for hyp in running_hyps: # scoring - weighted_scores = paddle.zeros(self.n_vocab, dtype=x.dtype) + weighted_scores = paddle.zeros([self.n_vocab], dtype=x.dtype) scores, states = self.score_full(hyp, x) for k in self.full_scorers: weighted_scores += self.weights[k] * scores[k] # partial scoring if self.do_pre_beam: - pre_beam_scores = ( - weighted_scores - if self.pre_beam_score_key == "full" - else scores[self.pre_beam_score_key] - ) + pre_beam_scores = (weighted_scores + if self.pre_beam_score_key == "full" else + scores[self.pre_beam_score_key]) part_ids = paddle.topk(pre_beam_scores, self.pre_beam_size)[1] part_scores, part_states = self.score_partial(hyp, part_ids, x) for k in self.part_scorers: @@ -332,22 +337,21 @@ class BeamSearch(paddle.nn.Layer): Hypothesis( score=weighted_scores[j], yseq=self.append_token(hyp.yseq, j), - scores=self.merge_scores( - hyp.scores, scores, j, part_scores, part_j - ), + scores=self.merge_scores(hyp.scores, scores, j, + part_scores, part_j), states=self.merge_states(states, part_states, part_j), - ) - ) + )) # sort and prune 2 x beam -> beam - best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ - : min(len(best_hyps), self.beam_size) - ] + best_hyps = sorted( + best_hyps, key=lambda x: x.score, + reverse=True)[:min(len(best_hyps), self.beam_size)] return best_hyps - def forward( - self, x: paddle.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 - ) -> List[Hypothesis]: + def forward(self, + x: paddle.Tensor, + maxlenratio: float=0.0, + minlenratio: float=0.0) -> List[Hypothesis]: """Perform beam search. Args: @@ -382,9 +386,11 @@ class BeamSearch(paddle.nn.Layer): logger.debug("position " + str(i)) best = self.search(running_hyps, x) # post process of one iteration - running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) + running_hyps = self.post_process(i, maxlen, maxlenratio, best, + ended_hyps) # end detection - if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): + if maxlenratio == 0.0 and end_detect( + [h.asdict() for h in ended_hyps], i): logger.info(f"end detected at {i}") break if len(running_hyps) == 0: @@ -396,41 +402,39 @@ class BeamSearch(paddle.nn.Layer): nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) # check the number of hypotheses reaching to eos if len(nbest_hyps) == 0: - logger.warning( - "there is no N-best results, perform recognition " - "again with smaller minlenratio." 
- ) - return ( - [] - if minlenratio < 0.1 - else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)) - ) + logger.warning("there are no N-best results, perform recognition " + "again with smaller minlenratio.") + return ([] if minlenratio < 0.1 else + self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))) # report the best result best = nbest_hyps[0] for k, v in best.scores.items(): logger.info( - f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}" + f"{float(v):6.2f} * {self.weights[k]:3} = {float(v) * self.weights[k]:6.2f} for {k}" ) - logger.info(f"total log probability: {best.score:.2f}") - logger.info(f"normalized log probability: {best.score / len(best.yseq):.2f}") + logger.info(f"total log probability: {float(best.score):.2f}") + logger.info( + f"normalized log probability: {float(best.score) / len(best.yseq):.2f}" + ) logger.info(f"total number of ended hypotheses: {len(nbest_hyps)}") if self.token_list is not None: - logger.info( - "best hypo: " - + "".join([self.token_list[x] for x in best.yseq[1:-1]]) - + "\n" - ) + # logger.info( + # "best hypo: " + # + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + # + "\n" + # ) + logger.info("best hypo: " + "".join( + [self.token_list[x] for x in best.yseq[1:]]) + "\n") return nbest_hyps def post_process( - self, - i: int, - maxlen: int, - maxlenratio: float, - running_hyps: List[Hypothesis], - ended_hyps: List[Hypothesis], - ) -> List[Hypothesis]: + self, + i: int, + maxlen: int, + maxlenratio: float, + running_hyps: List[Hypothesis], + ended_hyps: List[Hypothesis], ) -> List[Hypothesis]: """Perform post-processing of beam search iterations. Args: @@ -446,10 +450,8 @@ class BeamSearch(paddle.nn.Layer): """ logger.debug(f"the number of running hypotheses: {len(running_hyps)}") if self.token_list is not None: - logger.debug( - "best hypo: " - + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) - ) + logger.debug("best hypo: " + "".join( + [self.token_list[x] for x in running_hyps[0].yseq[1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logger.info("adding <eos> in the last position in the loop") @@ -464,7 +466,8 @@ class BeamSearch(paddle.nn.Layer): for hyp in running_hyps: if hyp.yseq[-1] == self.eos: # e.g., Word LM needs to add final <eos> score for k, d in chain(self.full_scorers.items(), + self.part_scorers.items()): s = d.final_score(hyp.states[k]) hyp.scores[k] += s hyp = hyp._replace(score=hyp.score + self.weights[k] * s) @@ -475,19 +478,18 @@ class BeamSearch(paddle.nn.Layer): def beam_search( - x: paddle.Tensor, - sos: int, - eos: int, - beam_size: int, - vocab_size: int, - scorers: Dict[str, ScorerInterface], - weights: Dict[str, float], - token_list: List[str] = None, - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - pre_beam_ratio: float = 1.5, - pre_beam_score_key: str = "full", -) -> list: + x: paddle.Tensor, + sos: int, + eos: int, + beam_size: int, + vocab_size: int, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + token_list: List[str]=None, + maxlenratio: float=0.0, + minlenratio: float=0.0, + pre_beam_ratio: float=1.5, + pre_beam_score_key: str="full", ) -> list: """Perform beam search with scorers. 
Args: @@ -523,6 +525,6 @@ def beam_search( pre_beam_score_key=pre_beam_score_key, sos=sos, eos=eos, - token_list=token_list, - ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) + token_list=token_list, ).forward( + x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) return [h.asdict() for h in ret] diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py new file mode 100644 index 00000000..c8df65d6 --- /dev/null +++ b/deepspeech/decoders/recog.py @@ -0,0 +1,195 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`.""" +import json +from pathlib import Path + +import jsonlines +import paddle +import yaml +from yacs.config import CfgNode + +from .beam_search import BatchBeamSearch +from .beam_search import BeamSearch +from .scorers.length_bonus import LengthBonus +from .scorers.scorer_interface import BatchScorerInterface +from .utils import add_results_to_json +from deepspeech.exps import dynamic_import_tester +from deepspeech.io.reader import LoadInputsAndTargets +from deepspeech.models.asr_interface import ASRInterface +from deepspeech.utils.log import Log +# from espnet.asr.asr_utils import get_model_conf +# from espnet.asr.asr_utils import torch_load +# from espnet.nets.lm_interface import dynamic_import_lm + +logger = Log(__name__).getlog() + +# NOTE: you need this func to generate our sphinx doc + + +def load_trained_model(args): + args.nprocs = args.ngpu + confs = CfgNode() + confs.set_new_allowed(True) + confs.merge_from_file(args.model_conf) + class_obj = dynamic_import_tester(args.model_name) + exp = class_obj(confs, args) + with exp.eval(): + exp.setup() + exp.restore() + char_list = exp.args.char_list + model = exp.model + return model, char_list, exp, confs + + +def recog_v2(args): + """Decode with custom models that implements ScorerInterface. + + Args: + args (namespace): The program arguments. 
+ See py:func:`bin.asr_recog.get_parser` for details + + """ + logger.warning("experimental API for custom LMs is selected by --api v2") + if args.batchsize > 1: + raise NotImplementedError("multi-utt batch decoding is not implemented") + if args.streaming_mode is not None: + raise NotImplementedError("streaming mode is not implemented") + if args.word_rnnlm: + raise NotImplementedError("word LM is not implemented") + + # set_deterministic(args) + model, char_list, exp, confs = load_trained_model(args) + assert isinstance(model, ASRInterface) + + load_inputs_and_targets = LoadInputsAndTargets( + mode="asr", + load_output=False, + sort_in_input_length=False, + preprocess_conf=confs.collator.augmentation_config + if args.preprocess_conf is None else args.preprocess_conf, + preprocess_args={"train": False}, + ) + + if args.rnnlm: + lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) + # NOTE: for a compatibility with less than 0.5.0 version models + lm_model_module = getattr(lm_args, "model_module", "default") + lm_class = dynamic_import_lm(lm_model_module, lm_args.backend) + lm = lm_class(len(char_list), lm_args) + torch_load(args.rnnlm, lm) + lm.eval() + else: + lm = None + + if args.ngram_model: + from .scorers.ngram import NgramFullScorer + from .scorers.ngram import NgramPartScorer + + if args.ngram_scorer == "full": + ngram = NgramFullScorer(args.ngram_model, char_list) + else: + ngram = NgramPartScorer(args.ngram_model, char_list) + else: + ngram = None + + scorers = model.scorers() # decoder + scorers["lm"] = lm + scorers["ngram"] = ngram + scorers["length_bonus"] = LengthBonus(len(char_list)) + weights = dict( + decoder=1.0 - args.ctc_weight, + ctc=args.ctc_weight, + lm=args.lm_weight, + ngram=args.ngram_weight, + length_bonus=args.penalty, + ) + beam_search = BeamSearch( + beam_size=args.beam_size, + vocab_size=len(char_list), + weights=weights, + scorers=scorers, + sos=model.sos, + eos=model.eos, + token_list=char_list, + pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", + ) + + # TODO(karita): make all scorers batchfied + if args.batchsize == 1: + non_batch = [ + k for k, v in beam_search.full_scorers.items() + if not isinstance(v, BatchScorerInterface) + ] + if len(non_batch) == 0: + beam_search.__class__ = BatchBeamSearch + logger.info("BatchBeamSearch implementation is selected.") + else: + logger.warning(f"As non-batch scorers {non_batch} are found, " + f"fall back to non-batch implementation.") + + if args.ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + if args.ngpu == 1: + device = "gpu:0" + else: + device = "cpu" + paddle.set_device(device) + dtype = getattr(paddle, args.dtype) + logger.info(f"Decoding device={device}, dtype={dtype}") + model.to(device=device, dtype=dtype) + model.eval() + beam_search.to(device=device, dtype=dtype) + beam_search.eval() + + # read json data + js = [] + with jsonlines.open(args.recog_json, "r") as reader: + for item in reader: + js.append(item) + # jsonlines to dict, key by 'utt', value by jsonline + js = {item['utt']: item for item in js} + + new_js = {} + with paddle.no_grad(): + with jsonlines.open(args.result_label, "w") as f: + for idx, name in enumerate(js.keys(), 1): + logger.info(f"({idx}/{len(js.keys())}) decoding " + name) + batch = [(name, js[name])] + feat = load_inputs_and_targets(batch)[0][0] + logger.info(f'feat: {feat.shape}') + enc = model.encode(paddle.to_tensor(feat).to(dtype)) + logger.info(f'eout: {enc.shape}') + nbest_hyps = beam_search(x=enc, + 
maxlenratio=args.maxlenratio, + minlenratio=args.minlenratio) + nbest_hyps = [ + h.asdict() + for h in nbest_hyps[:min(len(nbest_hyps), args.nbest)] + ] + new_js[name] = add_results_to_json(js[name], nbest_hyps, + char_list) + + item = new_js[name]['output'][0] # 1-best + ref = item['text'] + rec_text = item['rec_text'].replace('▁', + ' ').replace('<eos>', + '').strip() + rec_tokenid = list(map(int, item['rec_tokenid'].split())) + f.write({ + "utt": name, + "refs": [ref], + "hyps": [rec_text], + "hyps_tokenid": [rec_tokenid], + }) diff --git a/deepspeech/decoders/recog_bin.py b/deepspeech/decoders/recog_bin.py new file mode 100644 index 00000000..567dfecd --- /dev/null +++ b/deepspeech/decoders/recog_bin.py @@ -0,0 +1,376 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""End-to-end speech recognition model decoding script.""" +import logging +import os +import random +import sys +from distutils.util import strtobool + +import configargparse +import numpy as np + +from .recog import recog_v2 + + +def get_parser(): + """Get default arguments.""" + parser = configargparse.ArgumentParser( + description="Transcribe text from speech using " + "a speech recognition model on one CPU or GPU", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, ) + parser.add( + '--model-name', + type=str, + default='u2_kaldi', + help='model name, e.g.: deepspeech2, u2, u2_kaldi, u2_st') + # general configuration + parser.add("--config", is_config_file=True, help="Config file path") + parser.add( + "--config2", + is_config_file=True, + help="Second config file path that overwrites the settings in `--config`", + ) + parser.add( + "--config3", + is_config_file=True, + help="Third config file path that overwrites the settings " + "in `--config` and `--config2`", ) + + parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs") + parser.add_argument( + "--dtype", + choices=("float16", "float32", "float64"), + default="float32", + help="Float precision (only available in --api v2)", ) + parser.add_argument("--debugmode", type=int, default=1, help="Debugmode") + parser.add_argument("--seed", type=int, default=1, help="Random seed") + parser.add_argument( + "--verbose", "-V", type=int, default=2, help="Verbose option") + parser.add_argument( + "--batchsize", + type=int, + default=1, + help="Batch size for beam search (0: means no batch processing)", ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", ) + parser.add_argument( + "--api", + default="v2", + choices=["v2"], + help="Beam search APIs " + "v2: Experimental API. 
It supports any model that implements ScorerInterface.", + ) + # task related + parser.add_argument( + "--recog-json", type=str, help="Filename of recognition data (json)") + parser.add_argument( + "--result-label", + type=str, + required=True, + help="Filename of result label data (json)", ) + # model (parameter) related + parser.add_argument( + "--model", + type=str, + required=True, + help="Model file parameters to read") + parser.add_argument( + "--model-conf", type=str, default=None, help="Model config file") + parser.add_argument( + "--num-spkrs", + type=int, + default=1, + choices=[1, 2], + help="Number of speakers in the speech", ) + parser.add_argument( + "--num-encs", + default=1, + type=int, + help="Number of encoders in the model.") + # search related + parser.add_argument( + "--nbest", type=int, default=1, help="Output N-best hypotheses") + parser.add_argument("--beam-size", type=int, default=1, help="Beam size") + parser.add_argument( + "--penalty", type=float, default=0.0, help="Insertion penalty") + parser.add_argument( + "--maxlenratio", + type=float, + default=0.0, + help="""Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses an end-detect function + to automatically find maximum hypothesis lengths. + If maxlenratio<0.0, its absolute value is interpreted + as a constant max output length""", ) + parser.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain min output length", ) + parser.add_argument( + "--ctc-weight", + type=float, + default=0.0, + help="CTC weight in joint decoding") + parser.add_argument( + "--weights-ctc-dec", + type=float, + action="append", + help="ctc weight assigned to each encoder during decoding." + "[in multi-encoder mode only]", ) + parser.add_argument( + "--ctc-window-margin", + type=int, + default=0, + help="""Use CTC window with margin parameter to accelerate + CTC/attention decoding especially on GPU. Smaller margin + makes decoding faster, but may increase search errors. + If margin=0 (default), this function is disabled""", ) + # transducer related + parser.add_argument( + "--search-type", + type=str, + default="default", + choices=["default", "nsc", "tsd", "alsd", "maes"], + help="""Type of beam search implementation to use during inference. 
+ Can be either: default beam search ("default"), + N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"), + Alignment-Length Synchronous Decoding ("alsd") or + modified Adaptive Expansion Search ("maes").""", ) + parser.add_argument( + "--nstep", + type=int, + default=1, + help="""Number of expansion steps allowed in NSC beam search or mAES + (nstep > 0 for NSC and nstep > 1 for mAES).""", ) + parser.add_argument( + "--prefix-alpha", + type=int, + default=2, + help="Length prefix difference allowed in NSC beam search or mAES.", ) + parser.add_argument( + "--max-sym-exp", + type=int, + default=2, + help="Number of symbol expansions allowed in TSD.", ) + parser.add_argument( + "--u-max", + type=int, + default=400, + help="Length prefix difference allowed in ALSD.", ) + parser.add_argument( + "--expansion-gamma", + type=float, + default=2.3, + help="Allowed logp difference for prune-by-value method in mAES.", ) + parser.add_argument( + "--expansion-beta", + type=int, + default=2, + help="""Number of additional candidates for expanded hypotheses + selection in mAES.""", ) + parser.add_argument( + "--score-norm", + type=strtobool, + nargs="?", + default=True, + help="Normalize final hypotheses' score by length", ) + parser.add_argument( + "--softmax-temperature", + type=float, + default=1.0, + help="Penalization term for softmax function.", ) + # rnnlm related + parser.add_argument( + "--rnnlm", type=str, default=None, help="RNNLM model file to read") + parser.add_argument( + "--rnnlm-conf", + type=str, + default=None, + help="RNNLM model config file to read") + parser.add_argument( + "--word-rnnlm", + type=str, + default=None, + help="Word RNNLM model file to read") + parser.add_argument( + "--word-rnnlm-conf", + type=str, + default=None, + help="Word RNNLM model config file to read", ) + parser.add_argument( + "--word-dict", type=str, default=None, help="Word list to read") + parser.add_argument( + "--lm-weight", type=float, default=0.1, help="RNNLM weight") + # ngram related + parser.add_argument( + "--ngram-model", + type=str, + default=None, + help="ngram model file to read") + parser.add_argument( + "--ngram-weight", type=float, default=0.1, help="ngram weight") + parser.add_argument( + "--ngram-scorer", + type=str, + default="part", + choices=("full", "part"), + help="""if the ngram is set as a part scorer, similar to the CTC scorer, + ngram scorer only scores topK hypotheses. + if the ngram is set as full scorer, ngram scorer scores all hypotheses. + the decoding speed of part scorer is much faster than full one""", + ) + # streaming related + parser.add_argument( + "--streaming-mode", + type=str, + default=None, + choices=["window", "segment"], + help="""Use streaming recognizer for inference. + `--batchsize` must be set to 0 to enable this mode""", ) + parser.add_argument( + "--streaming-window", type=int, default=10, help="Window size") + parser.add_argument( + "--streaming-min-blank-dur", + type=int, + default=10, + help="Minimum blank duration threshold", ) + parser.add_argument( + "--streaming-onset-margin", type=int, default=1, help="Onset margin") + parser.add_argument( + "--streaming-offset-margin", type=int, default=1, help="Offset margin") + # non-autoregressive related + # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail. + parser.add_argument( + "--maskctc-n-iterations", + type=int, + default=10, + help="Number of decoding iterations." 
+ "For Mask CTC, set 0 to predict 1 mask/iter.", ) + parser.add_argument( + "--maskctc-probability-threshold", + type=float, + default=0.999, + help="Threshold probability for CTC output", ) + # quantize model related + parser.add_argument( + "--quantize-config", + nargs="*", + help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]", + ) + parser.add_argument( + "--quantize-dtype", + type=str, + default="qint8", + help="Dtype dynamic quantize") + parser.add_argument( + "--quantize-asr-model", + type=bool, + default=False, + help="Quantize asr model", ) + parser.add_argument( + "--quantize-lm-model", + type=bool, + default=False, + help="Quantize lm model", ) + return parser + + +def main(args): + """Run the main decoding function.""" + parser = get_parser() + parser.add_argument( + "--output", metavar="CKPT_DIR", help="path to save checkpoint.") + parser.add_argument( + "--checkpoint_path", type=str, help="path to load checkpoint") + parser.add_argument("--dict-path", type=str, help="path to load checkpoint") + args = parser.parse_args(args) + + if args.ngpu == 0 and args.dtype == "float16": + raise ValueError( + f"--dtype {args.dtype} does not support the CPU backend.") + + # logging info + if args.verbose == 1: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose == 2: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + logging.info(args) + + # check CUDA_VISIBLE_DEVICES + if args.ngpu > 0: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is None: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + elif args.ngpu != len(cvd.split(",")): + logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") + sys.exit(1) + + # TODO(mn5k): support of multiple GPUs + if args.ngpu > 1: + logging.error("The program only supports ngpu=1.") + sys.exit(1) + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # seed setting + random.seed(args.seed) + np.random.seed(args.seed) + logging.info("set random seed = %d" % args.seed) + + # validate rnn options + if args.rnnlm is not None and args.word_rnnlm is not None: + logging.error( + "It seems that both --rnnlm and --word-rnnlm are specified. 
" + "Please use either option.") + sys.exit(1) + + # recog + if args.num_spkrs == 1: + if args.num_encs == 1: + # Experimental API that supports custom LMs + if args.api == "v2": + from deepspeech.decoders.recog import recog_v2 + recog_v2(args) + else: + raise ValueError("Only support --api v2") + else: + if args.api == "v2": + raise NotImplementedError( + f"--num-encs {args.num_encs} > 1 is not supported in --api v2" + ) + elif args.num_spkrs == 2: + raise ValueError("asr_mix not supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/deepspeech/decoders/scorers/ctc.py b/deepspeech/decoders/scorers/ctc.py index 36b4bfd3..4871d6e1 100644 --- a/deepspeech/decoders/scorers/ctc.py +++ b/deepspeech/decoders/scorers/ctc.py @@ -15,8 +15,8 @@ import numpy as np import paddle -from .ctc_prefix_score import CTCPrefixScorer -from .ctc_prefix_score import CTCPrefixScorerPD +from .ctc_prefix_score import CTCPrefixScore +from .ctc_prefix_score import CTCPrefixScorePD from .scorer_interface import BatchPartialScorerInterface diff --git a/deepspeech/decoders/scorers/ctc_prefix_score.py b/deepspeech/decoders/scorers/ctc_prefix_score.py index 5f568c81..c85d546d 100644 --- a/deepspeech/decoders/scorers/ctc_prefix_score.py +++ b/deepspeech/decoders/scorers/ctc_prefix_score.py @@ -6,7 +6,7 @@ import paddle import six -class CTCPrefixScorerPD(): +class CTCPrefixScorePD(): """Batch processing of CTCPrefixScore which is based on Algorithm 2 in WATANABE et al. @@ -267,7 +267,7 @@ class CTCPrefixScorerPD(): return (r_prev_new, s_prev, f_min_prev, f_max_prev) -class CTCPrefixScorer(): +class CTCPrefixScore(): """Compute CTC label sequence scores which is based on Algorithm 2 in WATANABE et al. diff --git a/deepspeech/decoders/scorers/ngram.py b/deepspeech/decoders/scorers/ngram.py index 050a8c81..a34d8248 100644 --- a/deepspeech/decoders/scorers/ngram.py +++ b/deepspeech/decoders/scorers/ngram.py @@ -85,8 +85,9 @@ class NgramFullScorer(Ngrambase, BatchScorerInterface): and next state list for ys. """ - return self.score_partial_( - y, paddle.to_tensor(range(self.charlen)), state, x) + return self.score_partial_(y, + paddle.to_tensor(range(self.charlen)), state, + x) class NgramPartScorer(Ngrambase, PartialScorerInterface): diff --git a/deepspeech/decoders/scorers/score_interface.py b/deepspeech/decoders/scorers/scorer_interface.py similarity index 100% rename from deepspeech/decoders/scorers/score_interface.py rename to deepspeech/decoders/scorers/scorer_interface.py diff --git a/deepspeech/decoders/utils.py b/deepspeech/decoders/utils.py index 92f65814..3ed9c5da 100644 --- a/deepspeech/decoders/utils.py +++ b/deepspeech/decoders/utils.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np -__all__ = ["end_detect"] +from deepspeech.utils.log import Log +logger = Log(__name__).getlog() + +__all__ = ["end_detect", "parse_hypothesis", "add_results_to_json"] def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): @@ -47,3 +51,79 @@ def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): return True else: return False + + +# * ------------------ recognition related ------------------ * +def parse_hypothesis(hyp, char_list): + """Parse hypothesis. + + Args: + hyp (list[dict[str, Any]]): Recognition hypothesis. + char_list (list[str]): List of characters. 
+ + Returns: + tuple(str, str, str, float) + + """ + # remove sos and get results + tokenid_as_list = list(map(int, hyp["yseq"][1:])) + token_as_list = [char_list[idx] for idx in tokenid_as_list] + score = float(hyp["score"]) + + # convert to string + tokenid = " ".join([str(idx) for idx in tokenid_as_list]) + token = " ".join(token_as_list) + text = "".join(token_as_list).replace("<space>", " ") + + return text, token, tokenid, score + + +def add_results_to_json(js, nbest_hyps, char_list): + """Add N-best results to json. + + Args: + js (dict[str, Any]): Groundtruth utterance dict. + nbest_hyps_sd (list[dict[str, Any]]): + List of hypothesis for multi_speakers: nutts x nspkrs. + char_list (list[str]): List of characters. + + Returns: + dict[str, Any]: N-best results added utterance dict. + + """ + # copy old json info + new_js = dict() + new_js["utt2spk"] = js["utt2spk"] + new_js["output"] = [] + + for n, hyp in enumerate(nbest_hyps, 1): + # parse hypothesis + rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, + char_list) + + # copy ground-truth + if len(js["output"]) > 0: + out_dic = dict(js["output"][0].items()) + else: + # for no reference case (e.g., speech translation) + out_dic = {"name": ""} + + # update name + out_dic["name"] += "[%d]" % n + + # add recognition results + out_dic["rec_text"] = rec_text + out_dic["rec_token"] = rec_token + out_dic["rec_tokenid"] = rec_tokenid + out_dic["score"] = score + + # add to list of N-best result dicts + new_js["output"].append(out_dic) + + # show 1-best result + if n == 1: + if "text" in out_dic.keys(): + logger.info("groundtruth: %s" % out_dic["text"]) + logger.info("prediction : %s" % out_dic["rec_text"]) + + return new_js diff --git a/deepspeech/exps/__init__.py b/deepspeech/exps/__init__.py index 185a92b8..29953014 100644 --- a/deepspeech/exps/__init__.py +++ b/deepspeech/exps/__init__.py @@ -11,3 +11,52 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from deepspeech.training.trainer import Trainer +from deepspeech.utils.dynamic_import import dynamic_import + +model_trainer_alias = { + "ds2": "deepspeech.exps.deepspeech2.model:DeepSpeech2Trainer", + "u2": "deepspeech.exps.u2.model:U2Trainer", + "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer", + "u2_st": "deepspeech.exps.u2_st.model:U2STTrainer", +} + + +def dynamic_import_trainer(module): + """Import Trainer dynamically. + + Args: + module (str): trainer name. e.g., ds2, u2, u2_kaldi + + Returns: + type: Trainer class + + """ + model_class = dynamic_import(module, model_trainer_alias) + assert issubclass(model_class, + Trainer), f"{module} does not implement Trainer" + return model_class + + +model_tester_alias = { + "ds2": "deepspeech.exps.deepspeech2.model:DeepSpeech2Tester", + "u2": "deepspeech.exps.u2.model:U2Tester", + "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester", + "u2_st": "deepspeech.exps.u2_st.model:U2STTester", +} + + +def dynamic_import_tester(module): + """Import Tester dynamically. + + Args: + module (str): tester name. 
e.g., ds2, u2, u2_kaldi + + Returns: + type: Tester class + + """ + model_class = dynamic_import(module, model_tester_alias) + assert issubclass(model_class, + Trainer), f"{module} does not implement Tester" + return model_class diff --git a/deepspeech/exps/u2_kaldi/bin/recog.py b/deepspeech/exps/u2_kaldi/bin/recog.py new file mode 100644 index 00000000..e94a1ab1 --- /dev/null +++ b/deepspeech/exps/u2_kaldi/bin/recog.py @@ -0,0 +1,19 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +from deepspeech.decoders.recog_bin import main + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 18e29b28..f8624326 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -317,11 +317,9 @@ class U2Trainer(Trainer): with UpdateConfig(model_conf): model_conf.input_dim = self.train_loader.feat_dim model_conf.output_dim = self.train_loader.vocab_size - model = U2Model.from_config(model_conf) if self.parallel: model = paddle.DataParallel(model) - logger.info(f"{model}") layer_tools.print_params(model, logger.info) # lr @@ -436,8 +434,9 @@ class U2Tester(U2Trainer): simulate_streaming=cfg.simulate_streaming) decode_time = time.time() - start_time - for i, (utt, target, result, rec_tids) in enumerate(zip( - utts, target_transcripts, result_transcripts, result_tokenids)): + for i, (utt, target, result, rec_tids) in enumerate( + zip(utts, target_transcripts, result_transcripts, + result_tokenids)): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index 34220432..a6834ebc 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -140,7 +140,7 @@ class TextFeaturizer(): Returns: str: text string. """ - tokens = [t.replace(SPACE, " ") for t in tokens ] + tokens = [t.replace(SPACE, " ") for t in tokens] return "".join(tokens) def word_tokenize(self, text): @@ -207,7 +207,7 @@ class TextFeaturizer(): """Load vocabulary from file.""" vocab_list = load_dict(vocab_filepath, maskctc) assert vocab_list is not None - logger.info(f"Vocab: {pformat(vocab_list)}") + logger.debug(f"Vocab: {pformat(vocab_list)}") id2token = dict( [(idx, token) for (idx, token) in enumerate(vocab_list)]) diff --git a/deepspeech/models/asr_interface.py b/deepspeech/models/asr_interface.py new file mode 100644 index 00000000..7dac81b4 --- /dev/null +++ b/deepspeech/models/asr_interface.py @@ -0,0 +1,161 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ASR Interface module.""" +import argparse + +from deepspeech.utils.dynamic_import import dynamic_import + + +class ASRInterface: + """ASR Interface for ESPnet model implementation.""" + + @staticmethod + def add_arguments(parser): + """Add arguments to parser.""" + return parser + + @classmethod + def build(cls, idim: int, odim: int, **kwargs): + """Initialize this class with python-level args. + + Args: + idim (int): The number of an input feature dim. + odim (int): The number of output vocab. + + Returns: + ASRInterface: A new instance of ASRInterface. + + """ + args = argparse.Namespace(**kwargs) + return cls(idim, odim, args) + + def forward(self, xs, ilens, ys, olens): + """Compute loss for training. + + :param xs: batch of padded source sequences paddle.Tensor (B, Tmax, idim) + :param ilens: batch of lengths of source sequences (B), paddle.Tensor + :param ys: batch of padded target sequences paddle.Tensor (B, Lmax) + :param olens: batch of lengths of target sequences (B), paddle.Tensor + :return: loss value + :rtype: paddle.Tensor + """ + raise NotImplementedError("forward method is not implemented") + + def recognize(self, x, recog_args, char_list=None, rnnlm=None): + """Recognize x for evaluation. + + :param ndarray x: input acoustic feature (B, T, D) or (T, D) + :param namespace recog_args: argument namespace containing options + :param list char_list: list of characters + :param paddle.nn.Layer rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + raise NotImplementedError("recognize method is not implemented") + + def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None): + """Beam search implementation for batch. + + :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc) + :param namespace recog_args: argument namespace containing options + :param list char_list: list of characters + :param paddle.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + raise NotImplementedError("Batch decoding is not supported yet.") + + def calculate_all_attentions(self, xs, ilens, ys): + """Calculate attention. + + :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...] + :param ndarray ilens: batch of lengths of input sequences (B) + :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...] + :return: attention weights (B, Lmax, Tmax) + :rtype: float ndarray + """ + raise NotImplementedError( + "calculate_all_attentions method is not implemented") + + def calculate_all_ctc_probs(self, xs, ilens, ys): + """Calculate CTC probability. + + :param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...] + :param ndarray ilens: batch of lengths of input sequences (B) + :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...] 
+ :return: CTC probabilities (B, Tmax, vocab) + :rtype: float ndarray + """ + raise NotImplementedError( + "calculate_all_ctc_probs method is not implemented") + + @property + def attention_plot_class(self): + """Get attention plot class.""" + from espnet.asr.asr_utils import PlotAttentionReport + + return PlotAttentionReport + + @property + def ctc_plot_class(self): + """Get CTC plot class.""" + from espnet.asr.asr_utils import PlotCTCReport + + return PlotCTCReport + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + raise NotImplementedError( + "get_total_subsampling_factor method is not implemented") + + def encode(self, feat): + """Encode feature in `beam_search` (optional). + + Args: + x (numpy.ndarray): input feature (T, D) + Returns: + paddle.Tensor: encoded feature (T, D) + """ + raise NotImplementedError("encode method is not implemented") + + def scorers(self): + """Get scorers for `beam_search` (optional). + + Returns: + dict[str, ScorerInterface]: dict of `ScorerInterface` objects + + """ + raise NotImplementedError("decoders method is not implemented") + + +predefined_asr = { + "transformer": "deepspeech.models.u2:U2Model", + "conformer": "deepspeech.models.u2:U2Model", +} + + +def dynamic_import_asr(module): + """Import ASR models dynamically. + + Args: + module (str): asr name. e.g., transformer, conformer + + Returns: + type: ASR class + + """ + model_class = dynamic_import(module, predefined_asr) + assert issubclass(model_class, + ASRInterface), f"{module} does not implement ASRInterface" + return model_class diff --git a/deepspeech/models/u2/u2.py b/deepspeech/models/u2/u2.py index fd63fa9c..6cd3b775 100644 --- a/deepspeech/models/u2/u2.py +++ b/deepspeech/models/u2/u2.py @@ -28,8 +28,10 @@ from paddle import jit from paddle import nn from yacs.config import CfgNode +from deepspeech.decoders.scorers.ctc import CTCPrefixScorer from deepspeech.frontend.utility import IGNORE_ID from deepspeech.frontend.utility import load_cmvn +from deepspeech.models.asr_interface import ASRInterface from deepspeech.modules.cmvn import GlobalCMVN from deepspeech.modules.ctc import CTCDecoder from deepspeech.modules.decoder import TransformerDecoder @@ -55,7 +57,7 @@ __all__ = ["U2Model", "U2InferModel"] logger = Log(__name__).getlog() -class U2BaseModel(nn.Layer): +class U2BaseModel(ASRInterface, nn.Layer): """CTC-Attention hybrid Encoder-Decoder model""" @classmethod @@ -120,7 +122,7 @@ class U2BaseModel(nn.Layer): **kwargs): assert 0.0 <= ctc_weight <= 1.0, ctc_weight - super().__init__() + nn.Layer.__init__(self) # note that eos is the same as sos (equivalent ID) self.sos = vocab_size - 1 self.eos = vocab_size - 1 @@ -813,7 +815,27 @@ class U2BaseModel(nn.Layer): return res, res_tokenids -class U2Model(U2BaseModel): +class U2DecodeModel(U2BaseModel): + def scorers(self): + """Scorers.""" + return dict( + decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) + + def encode(self, x): + """Encode acoustic features. 
+ + :param ndarray x: source acoustic feature (T, D) + :return: encoder outputs + :rtype: paddle.Tensor + """ + self.eval() + x = paddle.to_tensor(x).unsqueeze(0) + ilen = x.size(1) + enc_output, _ = self._forward_encoder(x, ilen) + return enc_output.squeeze(0) + + +class U2Model(U2DecodeModel): def __init__(self, configs: dict): vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py index 1ae3ce37..735f06dc 100644 --- a/deepspeech/modules/decoder.py +++ b/deepspeech/modules/decoder.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Decoder definition.""" +from typing import Any from typing import List from typing import Optional from typing import Tuple @@ -20,10 +21,12 @@ import paddle from paddle import nn from typeguard import check_argument_types +from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.decoder_layer import DecoderLayer from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.mask import make_non_pad_mask +from deepspeech.modules.mask import make_xs_mask from deepspeech.modules.mask import subsequent_mask from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward from deepspeech.utils.log import Log @@ -33,7 +36,7 @@ logger = Log(__name__).getlog() __all__ = ["TransformerDecoder"] -class TransformerDecoder(nn.Layer): +class TransformerDecoder(BatchScorerInterface, nn.Layer): """Base class of Transfomer decoder module. Args: vocab_size: output dim @@ -71,7 +74,8 @@ class TransformerDecoder(nn.Layer): concat_after: bool=False, ): assert check_argument_types() - super().__init__() + nn.Layer.__init__(self) + self.selfattention_layer_type = 'selfattn' attention_dim = encoder_output_size if input_layer == "embed": @@ -180,3 +184,63 @@ class TransformerDecoder(nn.Layer): if self.use_output_layer: y = paddle.log_softmax(self.output_layer(y), axis=-1) return y, new_cache + + # beam search API (see ScorerInterface) + def score(self, ys, state, x): + """Score. + ys: (ylen,) + x: (xlen, n_feat) + """ + ys_mask = subsequent_mask(len(ys)).unsqueeze(0) # (B,L,L) + x_mask = make_xs_mask(x.unsqueeze(0)).unsqueeze(1) # (B,1,T) + if self.selfattention_layer_type != "selfattn": + # TODO(karita): implement cache + logger.warning( + f"{self.selfattention_layer_type} does not support cached decoding." + ) + state = None + logp, state = self.forward_one_step( + x.unsqueeze(0), x_mask, ys.unsqueeze(0), ys_mask, cache=state) + return logp.squeeze(0), state + + # batch beam search API (see BatchScorerInterface) + def batch_score(self, + ys: paddle.Tensor, + states: List[Any], + xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]: + """Score new token batch (required). + + Args: + ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (paddle.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[paddle.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. 
+ + """ + # merge states + n_batch = len(ys) + n_layers = len(self.decoders) + if states[0] is None: + batch_state = None + else: + # transpose state of [batch, layer] into [layer, batch] + batch_state = [ + paddle.stack([states[b][i] for b in range(n_batch)]) + for i in range(n_layers) + ] + + # batch decoding + ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0) # (B,L,L) + xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T) + logp, states = self.forward_one_step( + xs, xs_mask, ys, ys_mask, cache=batch_state) + + # transpose state of [layer, batch] into [batch, layer] + state_list = [[states[i][b] for i in range(n_layers)] + for b in range(n_batch)] + return logp, state_list diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py index 00f228a2..52f8e4bc 100644 --- a/deepspeech/modules/mask.py +++ b/deepspeech/modules/mask.py @@ -18,12 +18,25 @@ from deepspeech.utils.log import Log logger = Log(__name__).getlog() __all__ = [ - "make_pad_mask", "make_non_pad_mask", "subsequent_mask", + "make_xs_mask", "make_pad_mask", "make_non_pad_mask", "subsequent_mask", "subsequent_chunk_mask", "add_optional_chunk_mask", "mask_finished_scores", "mask_finished_preds" ] +def make_xs_mask(xs: paddle.Tensor, pad_value=0.0) -> paddle.Tensor: + """Maks mask tensor containing indices of non-padded part. + Args: + xs (paddle.Tensor): (B, T, D), zeros for pad. + Returns: + paddle.Tensor: Mask Tensor indices of non-padded part. (B, T) + """ + pad_frame = paddle.full([1, 1, xs.shape[-1]], pad_value, dtype=xs.dtype) + mask = xs != pad_frame + mask = mask.all(axis=-1) + return mask + + def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor: """Make mask tensor containing indices of padded part. See description of make_non_pad_mask. @@ -31,6 +44,7 @@ def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor: lengths (paddle.Tensor): Batch of lengths (B,). Returns: paddle.Tensor: Mask tensor containing indices of padded part. + (B, T) Examples: >>> lengths = [5, 3, 2] >>> make_pad_mask(lengths) @@ -62,6 +76,7 @@ def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor: lengths (paddle.Tensor): Batch of lengths (B,). Returns: paddle.Tensor: mask tensor containing indices of padded part. + (B, T) Examples: >>> lengths = [5, 3, 2] >>> make_non_pad_mask(lengths) diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index ca9652e3..14a34cb7 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -35,7 +35,7 @@ class LoadFromFile(argparse.Action): parser.parse_args(f.read().split(), namespace) -def default_argument_parser(): +def default_argument_parser(parser=None): r"""A simple yet genral argument parser for experiments with parakeet. This is used in examples with parakeet. 
diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py index ca9652e3..14a34cb7 100644 --- a/deepspeech/training/cli.py +++ b/deepspeech/training/cli.py @@ -35,7 +35,7 @@ class LoadFromFile(argparse.Action): parser.parse_args(f.read().split(), namespace) -def default_argument_parser(): +def default_argument_parser(parser=None): r"""A simple yet genral argument parser for experiments with parakeet. This is used in examples with parakeet. And it is intended to be used by @@ -62,7 +62,9 @@ def default_argument_parser(): argparse.ArgumentParser the parser """ - parser = argparse.ArgumentParser() + if parser is None: + parser = argparse.ArgumentParser() + parser.register('action', 'extend', ExtendAction) parser.add_argument( '--conf', type=open, action=LoadFromFile, help="config file.") diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 6c18fa36..2c238920 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -126,7 +126,8 @@ class Trainer(): logger.info(f"Set seed {args.seed}") # profiler and benchmark options - if self.args.benchmark_batch_size: + if hasattr(self.args, + "benchmark_batch_size") and self.args.benchmark_batch_size: with UpdateConfig(self.config): self.config.collator.batch_size = self.args.benchmark_batch_size self.config.training.log_interval = 1 @@ -326,12 +327,24 @@ class Trainer(): finally: self.destory() + def restore(self): + """Restore from a checkpoint for test/decode/export/align. + + ``args.checkpoint_path`` must be set; the named checkpoint is + loaded. Unlike ``resume_or_scratch``, this never falls back to + starting from scratch. + """ + assert self.args.checkpoint_path + infos = self.checkpoint.load_latest_parameters( + self.model, checkpoint_path=self.args.checkpoint_path) + return infos + def run_test(self): """Do Test/Decode""" try: with Timer("Test/Decode Done: {}"): with self.eval(): - self.resume_or_scratch() + self.restore() self.test() except KeyboardInterrupt: exit(-1) @@ -341,6 +354,7 @@ try: with Timer("Export Done: {}"): with self.eval(): + self.restore() self.export() except KeyboardInterrupt: exit(-1) @@ -350,7 +364,7 @@ try: with Timer("Align Done: {}"): with self.eval(): - self.resume_or_scratch() + self.restore() self.align() except KeyboardInterrupt: sys.exit(-1) diff --git a/examples/librispeech/README.md b/examples/librispeech/README.md index 72459095..5943cf1d 100644 --- a/examples/librispeech/README.md +++ b/examples/librispeech/README.md @@ -1,8 +1,8 @@ # ASR -* s0 is for deepspeech2 offline -* s1 is for transformer/conformer/U2 -* s2 is for transformer/conformer/U2 w/ kaldi feat, need install Kaldi +* s0 is for deepspeech2 offline +* s1 is for transformer/conformer/U2 +* s2 is for transformer/conformer/U2 w/ kaldi feat, need to install Kaldi ## Data | Data Subset | Duration in Seconds | diff --git a/examples/librispeech/s2/README.md b/examples/librispeech/s2/README.md index 34c65c11..d5df37d8 100644 --- a/examples/librispeech/s2/README.md +++ b/examples/librispeech/s2/README.md @@ -1,9 +1,14 @@ # LibriSpeech -## Transformer -| Model | Params | Config | Augmentation| Test Set | Decode Method | Loss | WER % | -| --- | --- | --- | --- | --- | --- | --- | --- | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.395054340362549 | 4.2 | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.395054340362549 | 5.0 | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.395054340362549 | | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescore | 6.395054340362549 | | +| Model | Params | Config | Augmentation | Loss | +| --- | --- | --- | --- | --- | +| transformer | 32.52 M | conf/transformer.yaml | spec_aug | 6.3197922706604 | + + +| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
+| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 | +| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 | +| test-clean | ctc_prefix_beam_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 | +| test-clean | attention_rescoring | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 | +| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 | diff --git a/examples/librispeech/s2/conf/decode/decode.yaml b/examples/librispeech/s2/conf/decode/decode.yaml new file mode 100644 index 00000000..4c702db5 --- /dev/null +++ b/examples/librispeech/s2/conf/decode/decode.yaml @@ -0,0 +1,7 @@ +batchsize: 0 +beam-size: 60 +ctc-weight: 0.0 +lm-weight: 0.0 +maxlenratio: 0.0 +minlenratio: 0.0 +penalty: 0.0 diff --git a/examples/librispeech/s2/conf/decode/decode_all.yaml b/examples/librispeech/s2/conf/decode/decode_all.yaml new file mode 100644 index 00000000..87d5f6d1 --- /dev/null +++ b/examples/librispeech/s2/conf/decode/decode_all.yaml @@ -0,0 +1,7 @@ +batchsize: 0 +beam-size: 60 +ctc-weight: 0.4 +lm-weight: 0.6 +maxlenratio: 0.0 +minlenratio: 0.0 +penalty: 0.0 \ No newline at end of file diff --git a/examples/librispeech/s2/conf/decode/decode_wo_lm.yaml b/examples/librispeech/s2/conf/decode/decode_wo_lm.yaml new file mode 100644 index 00000000..a82dba23 --- /dev/null +++ b/examples/librispeech/s2/conf/decode/decode_wo_lm.yaml @@ -0,0 +1,7 @@ +batchsize: 0 +beam-size: 60 +ctc-weight: 0.4 +lm-weight: 0.0 +maxlenratio: 0.0 +minlenratio: 0.0 +penalty: 0.0 \ No newline at end of file
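The three decode configs differ only in `ctc-weight` and `lm-weight`. For orientation, these weights enter the per-hypothesis score in the usual joint CTC/attention shallow-fusion form; the snippet below is a sketch of that standard formula, not this repo's exact scoring code:

# log-domain partial score of one hypothesis during joint decoding
ctc_weight = 0.4  # decode_all.yaml and decode_wo_lm.yaml
lm_weight = 0.6   # decode_all.yaml; 0.0 disables the LM term
penalty = 0.0     # per-token length bonus

def combined_score(att_logp, ctc_logp, lm_logp, n_tokens):
    return ((1.0 - ctc_weight) * att_logp
            + ctc_weight * ctc_logp
            + lm_weight * lm_logp
            + penalty * n_tokens)

Under this reading, decode.yaml (both weights 0.0) is pure attention decoding, and decode_wo_lm.yaml keeps the joint CTC term while dropping the LM.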
diff --git a/examples/librispeech/s2/local/recog.sh b/examples/librispeech/s2/local/recog.sh new file mode 100755 index 00000000..df3846c0 --- /dev/null +++ b/examples/librispeech/s2/local/recog.sh @@ -0,0 +1,109 @@ +#!/bin/bash + +set -e + +expdir=exp +datadir=data +nj=32 +tag= + +# decode config +decode_config=conf/decode/decode.yaml + +# lm params +lang_model=rnnlm.model.best +lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/ +lmtag='nolm' + +recog_set="test-clean test-other dev-clean dev-other" +recog_set="test-clean" + +# bpemode (unigram or bpe) +nbpe=5000 +bpemode=unigram +bpeprefix="data/bpe_${bpemode}_${nbpe}" +bpemodel=${bpeprefix}.model + +# bin params +config_path=conf/transformer.yaml +dict=data/bpe_unigram_5000_units.txt +ckpt_prefix= + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +if [ -z ${ckpt_prefix} ]; then + echo "usage: $0 --ckpt_prefix ckpt_prefix" + exit 1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +ckpt_dir=$(dirname `dirname ${ckpt_prefix}`) +echo "ckpt dir: ${ckpt_dir}" + +ckpt_tag=$(basename ${ckpt_prefix}) +echo "ckpt tag: ${ckpt_tag}" + +chunk_mode=false +if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then + chunk_mode=true +fi +echo "chunk mode: ${chunk_mode}" +echo "decode conf: ${decode_config}" + +# download language model +#bash local/download_lm_en.sh +#if [ $? -ne 0 ]; then +# exit 1 +#fi + + +pids=() # initialize pids + +for dmethd in join_ctc; do +( + echo "decode method: ${dmethd}" + for rtask in ${recog_set}; do + ( + echo "dataset: ${rtask}" + decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag}_${tag} + feat_recog_dir=${datadir} + mkdir -p ${decode_dir} + mkdir -p ${feat_recog_dir} + + # split data + split_json.sh manifest.${rtask} ${nj} + + #### use CPU for decoding + ngpu=0 + + # set batchsize 0 to disable batch decoding + ${decode_cmd} JOB=1:${nj} ${decode_dir}/log/decode.JOB.log \ + python3 -u ${BIN_DIR}/recog.py \ + --api v2 \ + --config ${decode_config} \ + --ngpu ${ngpu} \ + --batchsize 0 \ + --checkpoint_path ${ckpt_prefix} \ + --dict-path ${dict} \ + --recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \ + --result-label ${decode_dir}/data.JOB.json \ + --model-conf ${config_path} \ + --model ${ckpt_prefix}.pdparams + + #--rnnlm ${lmexpdir}/${lang_model} \ + + score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict} + + ) & + pids+=($!) # store background pids + i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done + [ ${i} -gt 0 ] && echo "$0: ${i} background jobs failed." || true + done +) +done + +echo "Finished" + +exit 0 diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh index 4a1cd238..5f662d29 100755 --- a/examples/librispeech/s2/local/test.sh +++ b/examples/librispeech/s2/local/test.sh @@ -6,7 +6,7 @@ expdir=exp datadir=data nj=32 -lmtag= +lmtag='nolm' recog_set="test-clean test-other dev-clean dev-other" recog_set="test-clean" @@ -17,23 +17,31 @@ bpemode=unigram bpeprefix="data/bpe_${bpemode}_${nbpe}" bpemodel=${bpeprefix}.model -if [ $# != 3 ];then - echo "usage: ${0} config_path dict_path ckpt_path_prefix" - exit -1 +config_path=conf/transformer.yaml +dict=data/bpe_unigram_5000_units.txt +ckpt_prefix= + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; + +if [ -z ${ckpt_prefix} ]; then + echo "usage: $0 --ckpt_prefix ckpt_prefix" + exit 1 fi ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..."
-config_path=$1 -dict=$2 -ckpt_prefix=$3 +ckpt_dir=$(dirname `dirname ${ckpt_prefix}`) +echo "ckpt dir: ${ckpt_dir}" + +ckpt_tag=$(basename ${ckpt_prefix}) +echo "ckpt tag: ${ckpt_tag}" chunk_mode=false if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then chunk_mode=true fi -echo "chunk mode ${chunk_mode}" +echo "chunk mode: ${chunk_mode}" # download language model @@ -46,13 +54,13 @@ pids=() # initialize pids for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do ( - echo "${dmethd} decoding" + echo "decode method: ${dmethd}" for rtask in ${recog_set}; do ( - echo "${rtask} dataset" - decode_dir=decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag} + echo "dataset: ${rtask}" + decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag} feat_recog_dir=${datadir} - mkdir -p ${expdir}/${decode_dir} + mkdir -p ${decode_dir} mkdir -p ${feat_recog_dir} # split data @@ -63,7 +71,7 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco # set batchsize 0 to disable batch decoding batch_size=1 - ${decode_cmd} JOB=1:${nj} ${expdir}/${decode_dir}/log/decode.JOB.log \ + ${decode_cmd} JOB=1:${nj} ${decode_dir}/log/decode.JOB.log \ python3 -u ${BIN_DIR}/test.py \ --model-name u2_kaldi \ --run-mode test \ @@ -71,12 +79,12 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco --dict-path ${dict} \ --config ${config_path} \ --checkpoint_path ${ckpt_prefix} \ - --result-file ${expdir}/${decode_dir}/data.JOB.json \ + --result-file ${decode_dir}/data.JOB.json \ --opts decoding.decoding_method ${dmethd} \ --opts decoding.batch_size ${batch_size} \ --opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} - score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${expdir}/${decode_dir} ${dict} + score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict} ) & pids+=($!) # store background pids diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh index 8a219381..3c7569fb 100755 --- a/examples/librispeech/s2/run.sh +++ b/examples/librispeech/s2/run.sh @@ -1,4 +1,5 @@ #!/bin/bash + set -e . ./path.sh || exit 1; @@ -7,8 +8,9 @@ set -e stage=0 stop_stage=100 conf_path=conf/transformer.yaml -dict_path=data/train_960_unigram5000_units.txt +dict_path=data/bpe_unigram_5000_units.txt avg_num=10 + source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; avg_ckpt=avg_${avg_num} diff --git a/requirements.txt b/requirements.txt index 2a3f0651..a7310a02 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,42 +1,43 @@ +ConfigArgParse coverage editdistance +g2p_en +g2pM gpustat +h5py +inflect +jieba jsonlines kaldiio +librosa +llvmlite loguru +matplotlib +nltk +numba +numpy==1.20.0 +pandas +phkit Pillow +praatio~=4.1 pre-commit pybind11 +pypinyin +pyworld resampy==0.2.2 sacrebleu scipy==1.2.1 sentencepiece snakeviz +soundfile~=0.10 sox tensorboardX textgrid +timer tqdm typeguard -visualdl==2.2.0 -yacs -numpy==1.20.0 -numba -nltk -inflect -librosa unidecode -llvmlite -matplotlib -pandas -soundfile~=0.10 -g2p_en -pypinyin +visualdl==2.2.0 webrtcvad -g2pM -praatio~=4.1 -h5py -timer -pyworld -jieba -phkit +yacs yq diff --git a/setup.py b/setup.py index 07b6aac0..bd982129 100644 --- a/setup.py +++ b/setup.py @@ -11,20 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import inspect import io import os import re +import subprocess as sp import sys from pathlib import Path -import contextlib -import inspect +from setuptools import Command from setuptools import find_packages from setuptools import setup -from setuptools import Command from setuptools.command.develop import develop from setuptools.command.install import install -import subprocess as sp HERE = Path(os.path.abspath(os.path.dirname(__file__))) @@ -40,16 +40,18 @@ def pushd(new_dir): def read(*names, **kwargs): - with io.open(os.path.join(os.path.dirname(__file__), *names), - encoding=kwargs.get("encoding", "utf8")) as fp: + with io.open( + os.path.join(os.path.dirname(__file__), *names), + encoding=kwargs.get("encoding", "utf8")) as fp: return fp.read() def check_call(cmd: str, shell=False, executable=None): try: - sp.check_call(cmd.split(), - shell=shell, - executable="/bin/bash" if shell else executable) + sp.check_call( + cmd.split(), + shell=shell, + executable="/bin/bash" if shell else executable) except sp.CalledProcessError as e: print( f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:", @@ -189,7 +191,6 @@ setup_info = dict( 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', - ], -) + ], ) setup(**setup_info) diff --git a/utils/avg_model.py b/utils/avg_model.py index 1fc00cb6..7c05ec78 100755 --- a/utils/avg_model.py +++ b/utils/avg_model.py @@ -25,8 +25,8 @@ def main(args): paddle.set_device('cpu') val_scores = [] - beat_val_scores = [] - selected_epochs = [] + beat_val_scores = None + selected_epochs = None jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json') jsons = sorted(jsons, key=os.path.getmtime, reverse=True) @@ -80,9 +80,10 @@ def main(args): data = json.dumps({ "mode": 'val_best' if args.val_best else 'latest', "avg_ckpt": args.dst_model, - "ckpt": path_list, - "epoch": selected_epochs.tolist(), - "val_loss": beat_val_scores.tolist(), + "val_loss_mean": np.mean(beat_val_scores), + "ckpts": path_list, + "epochs": selected_epochs.tolist(), + "val_losses": beat_val_scores.tolist(), }) f.write(data + "\n")
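For context on the `avg_model.py` bookkeeping above: the `avg_ckpt` it records is a plain parameter-wise mean over the selected epochs. A minimal sketch of that operation (hypothetical checkpoint paths; not the script's exact code):

import paddle

ckpt_paths = ["exp/ckpt/epoch_10.pdparams",  # hypothetical paths
              "exp/ckpt/epoch_11.pdparams"]

avg = None
for path in ckpt_paths:
    state = paddle.load(path)
    if avg is None:
        avg = {k: v.astype('float64') for k, v in state.items()}
    else:
        for k, v in state.items():
            avg[k] += v.astype('float64')

# divide once at the end, then cast back to the training dtype
avg = {k: (v / len(ckpt_paths)).astype('float32') for k, v in avg.items()}
paddle.save(avg, "exp/ckpt/avg_2.pdparams")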