From dfd80b3aa244d3037e887837721fed9080f024bf Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 23 Oct 2021 15:59:53 +0000 Subject: [PATCH] recog into decoders, format code --- deepspeech/__init__.py | 38 +- deepspeech/decoders/beam_search/__init__.py | 17 + .../decoders/beam_search/batch_beam_search.py | 17 + .../decoders/{ => beam_search}/beam_search.py | 236 ++++++----- deepspeech/decoders/recog.py | 134 +++--- deepspeech/decoders/recog_bin.py | 376 +++++++++++++++++ deepspeech/decoders/scorers/ngram.py | 5 +- deepspeech/decoders/utils.py | 7 +- deepspeech/exps/__init__.py | 49 +++ deepspeech/exps/u2_kaldi/bin/recog.py | 388 +----------------- deepspeech/exps/u2_kaldi/model.py | 5 +- .../frontend/featurizer/text_featurizer.py | 2 +- deepspeech/models/asr_interface.py | 39 +- deepspeech/models/u2/u2.py | 8 +- deepspeech/modules/decoder.py | 32 +- deepspeech/modules/mask.py | 2 +- deepspeech/training/cli.py | 2 +- deepspeech/training/trainer.py | 6 +- examples/librispeech/README.md | 6 +- examples/librispeech/s2/README.md | 14 +- .../librispeech/s2/conf/decode/decode.yaml | 2 +- examples/librispeech/s2/local/recog.sh | 24 +- examples/librispeech/s2/local/test.sh | 17 +- requirements.txt | 42 +- setup.py | 23 +- 25 files changed, 808 insertions(+), 683 deletions(-) create mode 100644 deepspeech/decoders/beam_search/__init__.py create mode 100644 deepspeech/decoders/beam_search/batch_beam_search.py rename deepspeech/decoders/{ => beam_search}/beam_search.py (74%) create mode 100644 deepspeech/decoders/recog_bin.py diff --git a/deepspeech/__init__.py b/deepspeech/__init__.py index ca5b4ba5..da3b1acb 100644 --- a/deepspeech/__init__.py +++ b/deepspeech/__init__.py @@ -233,7 +233,8 @@ def is_broadcastable(shp1, shp2): def masked_fill(xs: paddle.Tensor, mask: paddle.Tensor, value: Union[float, int]): - assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape, mask.shape) + assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape, + mask.shape) bshape = paddle.broadcast_shape(xs.shape, mask.shape) mask = mask.broadcast_to(bshape) trues = paddle.ones_like(xs) * value @@ -312,18 +313,18 @@ if not hasattr(paddle.Tensor, 'type_as'): def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: - assert len(args) == 1 - if isinstance(args[0], str): # dtype - return x.astype(args[0]) - elif isinstance(args[0], paddle.Tensor): # Tensor - return x.astype(args[0].dtype) - else: # Device - return x + assert len(args) == 1 + if isinstance(args[0], str): # dtype + return x.astype(args[0]) + elif isinstance(args[0], paddle.Tensor): # Tensor + return x.astype(args[0].dtype) + else: # Device + return x if not hasattr(paddle.Tensor, 'to'): - logger.debug("register user to to paddle.Tensor, remove this when fixed!") - setattr(paddle.Tensor, 'to', to) + logger.debug("register user to to paddle.Tensor, remove this when fixed!") + setattr(paddle.Tensor, 'to', to) def func_float(x: paddle.Tensor) -> paddle.Tensor: @@ -355,7 +356,6 @@ if not hasattr(paddle.Tensor, 'tolist'): setattr(paddle.Tensor, 'tolist', tolist) - ########### hcak paddle.nn.functional ############# # hack loss def ctc_loss(logits, @@ -384,7 +384,6 @@ logger.debug( ) F.ctc_loss = ctc_loss - ########### hcak paddle.nn ############# from paddle.nn import Layer from typing import Optional @@ -394,6 +393,7 @@ from typing import Tuple from typing import Iterator from collections import OrderedDict, abc as container_abcs + class LayerDict(paddle.nn.Layer): r"""Holds submodules in a dictionary. 
@@ -438,7 +438,7 @@ class LayerDict(paddle.nn.Layer): return x """ - def __init__(self, modules: Optional[Mapping[str, Layer]] = None) -> None: + def __init__(self, modules: Optional[Mapping[str, Layer]]=None) -> None: super(LayerDict, self).__init__() if modules is not None: self.update(modules) @@ -505,10 +505,11 @@ class LayerDict(paddle.nn.Layer): """ if not isinstance(modules, container_abcs.Iterable): raise TypeError("LayerDict.update should be called with an " - "iterable of key/value pairs, but got " + - type(modules).__name__) + "iterable of key/value pairs, but got " + type( + modules).__name__) - if isinstance(modules, (OrderedDict, LayerDict, container_abcs.Mapping)): + if isinstance(modules, + (OrderedDict, LayerDict, container_abcs.Mapping)): for key, module in modules.items(): self[key] = module else: @@ -520,14 +521,15 @@ class LayerDict(paddle.nn.Layer): type(m).__name__) if not len(m) == 2: raise ValueError("LayerDict update sequence element " - "#" + str(j) + " has length " + str(len(m)) + - "; 2 is required") + "#" + str(j) + " has length " + str( + len(m)) + "; 2 is required") # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] # that's too cumbersome to type correctly with overloads, so we add an ignore here self[m[0]] = m[1] # type: ignore[assignment] # remove forward alltogether to fallback on Module's _forward_unimplemented + if not hasattr(paddle.nn, 'LayerDict'): logger.debug( "register user LayerDict to paddle.nn, remove this when fixed!") diff --git a/deepspeech/decoders/beam_search/__init__.py b/deepspeech/decoders/beam_search/__init__.py new file mode 100644 index 00000000..79a1e9d3 --- /dev/null +++ b/deepspeech/decoders/beam_search/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .batch_beam_search import BatchBeamSearch +from .beam_search import beam_search +from .beam_search import BeamSearch +from .beam_search import Hypothesis diff --git a/deepspeech/decoders/beam_search/batch_beam_search.py b/deepspeech/decoders/beam_search/batch_beam_search.py new file mode 100644 index 00000000..3fc1c435 --- /dev/null +++ b/deepspeech/decoders/beam_search/batch_beam_search.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class BatchBeamSearch(): + pass diff --git a/deepspeech/decoders/beam_search.py b/deepspeech/decoders/beam_search/beam_search.py similarity index 74% rename from deepspeech/decoders/beam_search.py rename to deepspeech/decoders/beam_search/beam_search.py index afb8aefa..8fd8f9b8 100644 --- a/deepspeech/decoders/beam_search.py +++ b/deepspeech/decoders/beam_search/beam_search.py @@ -1,5 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Beam search module.""" - from itertools import chain from typing import Any from typing import Dict @@ -10,18 +22,18 @@ from typing import Union import paddle -from .utils import end_detect -from .scorers.scorer_interface import PartialScorerInterface -from .scorers.scorer_interface import ScorerInterface - +from ..scorers.scorer_interface import PartialScorerInterface +from ..scorers.scorer_interface import ScorerInterface +from ..utils import end_detect from deepspeech.utils.log import Log logger = Log(__name__).getlog() + class Hypothesis(NamedTuple): """Hypothesis data type.""" - yseq: paddle.Tensor # (T,) + yseq: paddle.Tensor # (T,) score: Union[float, paddle.Tensor] = 0 scores: Dict[str, Union[float, paddle.Tensor]] = dict() states: Dict[str, Any] = dict() @@ -31,25 +43,24 @@ class Hypothesis(NamedTuple): return self._replace( yseq=self.yseq.tolist(), score=float(self.score), - scores={k: float(v) for k, v in self.scores.items()}, - )._asdict() + scores={k: float(v) + for k, v in self.scores.items()}, )._asdict() class BeamSearch(paddle.nn.Layer): """Beam search implementation.""" def __init__( - self, - scorers: Dict[str, ScorerInterface], - weights: Dict[str, float], - beam_size: int, - vocab_size: int, - sos: int, - eos: int, - token_list: List[str] = None, - pre_beam_ratio: float = 1.5, - pre_beam_score_key: str = None, - ): + self, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + beam_size: int, + vocab_size: int, + sos: int, + eos: int, + token_list: List[str]=None, + pre_beam_ratio: float=1.5, + pre_beam_score_key: str=None, ): """Initialize beam search. 
Args: @@ -71,12 +82,12 @@ class BeamSearch(paddle.nn.Layer): super().__init__() # set scorers self.weights = weights - self.scorers = dict() # all = full + partial - self.full_scorers = dict() # full tokens - self.part_scorers = dict() # partial tokens + self.scorers = dict() # all = full + partial + self.full_scorers = dict() # full tokens + self.part_scorers = dict() # partial tokens # this module dict is required for recursive cast # `self.to(device, dtype)` in `recog.py` - self.nn_dict = paddle.nn.LayerDict() # nn.Layer + self.nn_dict = paddle.nn.LayerDict() # nn.Layer for k, v in scorers.items(): w = weights.get(k, 0) if w == 0 or v is None: @@ -100,20 +111,16 @@ class BeamSearch(paddle.nn.Layer): self.pre_beam_size = int(pre_beam_ratio * beam_size) self.beam_size = beam_size self.n_vocab = vocab_size - if ( - pre_beam_score_key is not None - and pre_beam_score_key != "full" - and pre_beam_score_key not in self.full_scorers - ): - raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}") + if (pre_beam_score_key is not None and pre_beam_score_key != "full" and + pre_beam_score_key not in self.full_scorers): + raise KeyError( + f"{pre_beam_score_key} is not found in {self.full_scorers}") # selected `key` scorer to do pre beam search self.pre_beam_score_key = pre_beam_score_key # do_pre_beam when need, valid and has part_scorers - self.do_pre_beam = ( - self.pre_beam_score_key is not None - and self.pre_beam_size < self.n_vocab - and len(self.part_scorers) > 0 - ) + self.do_pre_beam = (self.pre_beam_score_key is not None and + self.pre_beam_size < self.n_vocab and + len(self.part_scorers) > 0) def init_hyp(self, x: paddle.Tensor) -> List[Hypothesis]: """Get an initial hypothesis data. @@ -135,12 +142,12 @@ class BeamSearch(paddle.nn.Layer): yseq=paddle.to_tensor([self.sos], place=x.place), score=0.0, scores=init_scores, - states=init_states, - ) + states=init_states, ) ] @staticmethod - def append_token(xs: paddle.Tensor, x: Union[int, paddle.Tensor]) -> paddle.Tensor: + def append_token(xs: paddle.Tensor, + x: Union[int, paddle.Tensor]) -> paddle.Tensor: """Append new token to prefix tokens. Args: @@ -154,9 +161,8 @@ class BeamSearch(paddle.nn.Layer): x = paddle.to_tensor([x], dtype=xs.dtype) if isinstance(x, int) else x return paddle.concat((xs, x)) - def score_full( - self, hyp: Hypothesis, x: paddle.Tensor - ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: + def score_full(self, hyp: Hypothesis, x: paddle.Tensor + ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: """Score new hypothesis by `self.full_scorers`. Args: @@ -178,9 +184,11 @@ class BeamSearch(paddle.nn.Layer): scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x) return scores, states - def score_partial( - self, hyp: Hypothesis, ids: paddle.Tensor, x: paddle.Tensor - ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: + def score_partial(self, + hyp: Hypothesis, + ids: paddle.Tensor, + x: paddle.Tensor + ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: """Score new hypothesis by `self.part_scorers`. 
Args: @@ -201,12 +209,12 @@ class BeamSearch(paddle.nn.Layer): states = dict() for k, d in self.part_scorers.items(): # scores[k] shape (len(ids),) - scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x) + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], + x) return scores, states - def beam( - self, weighted_scores: paddle.Tensor, ids: paddle.Tensor - ) -> Tuple[paddle.Tensor, paddle.Tensor]: + def beam(self, weighted_scores: paddle.Tensor, + ids: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute topk full token ids and partial token ids. Args: @@ -223,7 +231,8 @@ class BeamSearch(paddle.nn.Layer): """ # no pre beam performed, `ids` equal to `weighted_scores` if weighted_scores.size(0) == ids.size(0): - top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab + top_ids = weighted_scores.topk( + self.beam_size)[1] # index in n_vocab return top_ids, top_ids # mask pruned in pre-beam not to select in topk @@ -231,18 +240,18 @@ class BeamSearch(paddle.nn.Layer): weighted_scores[:] = -float("inf") weighted_scores[ids] = tmp # top_ids no equal to local_ids, since ids shape not same - top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab - local_ids = weighted_scores[ids].topk(self.beam_size)[1] # index in len(ids) + top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab + local_ids = weighted_scores[ids].topk( + self.beam_size)[1] # index in len(ids) return top_ids, local_ids @staticmethod def merge_scores( - prev_scores: Dict[str, float], - next_full_scores: Dict[str, paddle.Tensor], - full_idx: int, - next_part_scores: Dict[str, paddle.Tensor], - part_idx: int, - ) -> Dict[str, paddle.Tensor]: + prev_scores: Dict[str, float], + next_full_scores: Dict[str, paddle.Tensor], + full_idx: int, + next_part_scores: Dict[str, paddle.Tensor], + part_idx: int, ) -> Dict[str, paddle.Tensor]: """Merge scores for new hypothesis. Args: @@ -288,9 +297,8 @@ class BeamSearch(paddle.nn.Layer): new_states[k] = d.select_state(part_states[k], part_idx) return new_states - def search( - self, running_hyps: List[Hypothesis], x: paddle.Tensor - ) -> List[Hypothesis]: + def search(self, running_hyps: List[Hypothesis], + x: paddle.Tensor) -> List[Hypothesis]: """Search new tokens for running hypotheses and encoded speech x. 
Args: @@ -311,11 +319,9 @@ class BeamSearch(paddle.nn.Layer): weighted_scores += self.weights[k] * scores[k] # partial scoring if self.do_pre_beam: - pre_beam_scores = ( - weighted_scores - if self.pre_beam_score_key == "full" - else scores[self.pre_beam_score_key] - ) + pre_beam_scores = (weighted_scores + if self.pre_beam_score_key == "full" else + scores[self.pre_beam_score_key]) part_ids = paddle.topk(pre_beam_scores, self.pre_beam_size)[1] part_scores, part_states = self.score_partial(hyp, part_ids, x) for k in self.part_scorers: @@ -331,22 +337,21 @@ class BeamSearch(paddle.nn.Layer): Hypothesis( score=weighted_scores[j], yseq=self.append_token(hyp.yseq, j), - scores=self.merge_scores( - hyp.scores, scores, j, part_scores, part_j - ), + scores=self.merge_scores(hyp.scores, scores, j, + part_scores, part_j), states=self.merge_states(states, part_states, part_j), - ) - ) + )) # sort and prune 2 x beam -> beam - best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ - : min(len(best_hyps), self.beam_size) - ] + best_hyps = sorted( + best_hyps, key=lambda x: x.score, + reverse=True)[:min(len(best_hyps), self.beam_size)] return best_hyps - def forward( - self, x: paddle.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 - ) -> List[Hypothesis]: + def forward(self, + x: paddle.Tensor, + maxlenratio: float=0.0, + minlenratio: float=0.0) -> List[Hypothesis]: """Perform beam search. Args: @@ -381,9 +386,11 @@ class BeamSearch(paddle.nn.Layer): logger.debug("position " + str(i)) best = self.search(running_hyps, x) # post process of one iteration - running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) + running_hyps = self.post_process(i, maxlen, maxlenratio, best, + ended_hyps) # end detection - if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): + if maxlenratio == 0.0 and end_detect( + [h.asdict() for h in ended_hyps], i): logger.info(f"end detected at {i}") break if len(running_hyps) == 0: @@ -395,15 +402,10 @@ class BeamSearch(paddle.nn.Layer): nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) # check the number of hypotheses reaching to eos if len(nbest_hyps) == 0: - logger.warning( - "there is no N-best results, perform recognition " - "again with smaller minlenratio." 
- )
- return (
- []
- if minlenratio < 0.1
- else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
- )
+ logger.warning("there are no N-best results, perform recognition "
+ "again with smaller minlenratio.")
+ return ([] if minlenratio < 0.1 else
+ self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)))

 # report the best result
 best = nbest_hyps[0]
@@ -412,7 +414,9 @@
 f"{float(v):6.2f} * {self.weights[k]:3} = {float(v) * self.weights[k]:6.2f} for {k}"
 )
 logger.info(f"total log probability: {float(best.score):.2f}")
- logger.info(f"normalized log probability: {float(best.score) / len(best.yseq):.2f}")
+ logger.info(
+ f"normalized log probability: {float(best.score) / len(best.yseq):.2f}"
+ )
 logger.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
 if self.token_list is not None:
 # logger.info(
 # "best hypo: "
 # + "".join([self.token_list[x] for x in best.yseq[1:-1]])
 # + "\n"
 # )
- logger.info(
- "best hypo: "
- + "".join([self.token_list[x] for x in best.yseq[1:]])
- + "\n"
- )
+ logger.info("best hypo: " + "".join(
+ [self.token_list[x] for x in best.yseq[1:]]) + "\n")
 return nbest_hyps

 def post_process(
- self,
- i: int,
- maxlen: int,
- maxlenratio: float,
- running_hyps: List[Hypothesis],
- ended_hyps: List[Hypothesis],
- ) -> List[Hypothesis]:
+ self,
+ i: int,
+ maxlen: int,
+ maxlenratio: float,
+ running_hyps: List[Hypothesis],
+ ended_hyps: List[Hypothesis], ) -> List[Hypothesis]:
 """Perform post-processing of beam search iterations.

 Args:
 i (int): The length of hypothesis tokens.
 maxlen (int): The maximum length of tokens in beam search.
 maxlenratio (int): The maximum length ratio in beam search.
 running_hyps (List[Hypothesis]): The running hypotheses in beam search.
 ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

 Returns:
 List[Hypothesis]: The hypotheses.

 """
 logger.debug(f"the number of running hypotheses: {len(running_hyps)}")
 if self.token_list is not None:
- logger.debug(
- "best hypo: "
- + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
- )
+ logger.debug("best hypo: " + "".join(
+ [self.token_list[x] for x in running_hyps[0].yseq[1:]]))
 # add eos in the final loop to avoid that there are no ended hyps
 if i == maxlen - 1:
 logger.info("adding <eos> in the last position in the loop")
@@ -468,7 +466,8 @@
 for hyp in running_hyps:
 if hyp.yseq[-1] == self.eos:
 # e.g., Word LM needs to add final score
- for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+ for k, d in chain(self.full_scorers.items(),
+ self.part_scorers.items()):
 s = d.final_score(hyp.states[k])
 hyp.scores[k] += s
 hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
@@ -479,19 +478,18 @@

 def beam_search(
- x: paddle.Tensor,
- sos: int,
- eos: int,
- beam_size: int,
- vocab_size: int,
- scorers: Dict[str, ScorerInterface],
- weights: Dict[str, float],
- token_list: List[str] = None,
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- pre_beam_ratio: float = 1.5,
- pre_beam_score_key: str = "full",
-) -> list:
+ x: paddle.Tensor,
+ sos: int,
+ eos: int,
+ beam_size: int,
+ vocab_size: int,
+ scorers: Dict[str, ScorerInterface],
+ weights: Dict[str, float],
+ token_list: List[str]=None,
+ maxlenratio: float=0.0,
+ minlenratio: float=0.0,
+ pre_beam_ratio: float=1.5,
+ pre_beam_score_key: str="full", ) -> list:
 """Perform beam search with scorers.
Args: @@ -527,6 +525,6 @@ def beam_search( pre_beam_score_key=pre_beam_score_key, sos=sos, eos=eos, - token_list=token_list, - ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) + token_list=token_list, ).forward( + x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) return [h.asdict() for h in ret] diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py index 867569aa..c8df65d6 100644 --- a/deepspeech/decoders/recog.py +++ b/deepspeech/decoders/recog.py @@ -1,37 +1,57 @@ -"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`.""" - +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`.""" import json +from pathlib import Path + +import jsonlines import paddle import yaml from yacs.config import CfgNode -from pathlib import Path -import jsonlines -# from espnet.asr.asr_utils import get_model_conf -# from espnet.asr.asr_utils import torch_load -# from espnet.asr.pytorch_backend.asr import load_trained_model -# from espnet.nets.lm_interface import dynamic_import_lm - -from deepspeech.models.asr_interface import ASRInterface - -from .utils import add_results_to_json -# from .batch_beam_search import BatchBeamSearch +from .beam_search import BatchBeamSearch from .beam_search import BeamSearch -from .scorers.scorer_interface import BatchScorerInterface from .scorers.length_bonus import LengthBonus - +from .scorers.scorer_interface import BatchScorerInterface +from .utils import add_results_to_json +from deepspeech.exps import dynamic_import_tester from deepspeech.io.reader import LoadInputsAndTargets +from deepspeech.models.asr_interface import ASRInterface from deepspeech.utils.log import Log +# from espnet.asr.asr_utils import get_model_conf +# from espnet.asr.asr_utils import torch_load +# from espnet.nets.lm_interface import dynamic_import_lm + logger = Log(__name__).getlog() +# NOTE: you need this func to generate our sphinx doc -from deepspeech.utils.dynamic_import import dynamic_import -from deepspeech.utils.utility import print_arguments -model_test_alias = { - "u2": "deepspeech.exps.u2.model:U2Tester", - "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester", -} +def load_trained_model(args): + args.nprocs = args.ngpu + confs = CfgNode() + confs.set_new_allowed(True) + confs.merge_from_file(args.model_conf) + class_obj = dynamic_import_tester(args.model_name) + exp = class_obj(confs, args) + with exp.eval(): + exp.setup() + exp.restore() + char_list = exp.args.char_list + model = exp.model + return model, char_list, exp, confs + def recog_v2(args): """Decode with custom models that implements ScorerInterface. 
@@ -48,33 +68,17 @@ def recog_v2(args): raise NotImplementedError("streaming mode is not implemented") if args.word_rnnlm: raise NotImplementedError("word LM is not implemented") - args.nprocs = args.ngpu - # set_deterministic(args) - - #model, train_args = load_trained_model(args.model) - model_path = Path(args.model) - ckpt_dir = model_path.parent.parent - - confs = CfgNode() - confs.set_new_allowed(True) - confs.merge_from_file(args.model_conf) - - class_obj = dynamic_import(args.model_name, model_test_alias) - exp = class_obj(confs, args) - with exp.eval(): - exp.setup() - exp.restore() - char_list = exp.args.char_list - model = exp.model + # set_deterministic(args) + model, char_list, exp, confs = load_trained_model(args) assert isinstance(model, ASRInterface) + load_inputs_and_targets = LoadInputsAndTargets( mode="asr", load_output=False, sort_in_input_length=False, preprocess_conf=confs.collator.augmentation_config - if args.preprocess_conf is None - else args.preprocess_conf, + if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={"train": False}, ) @@ -100,7 +104,7 @@ def recog_v2(args): else: ngram = None - scorers = model.scorers() + scorers = model.scorers() # decoder scorers["lm"] = lm scorers["ngram"] = ngram scorers["length_bonus"] = LengthBonus(len(char_list)) @@ -125,18 +129,15 @@ def recog_v2(args): # TODO(karita): make all scorers batchfied if args.batchsize == 1: non_batch = [ - k - for k, v in beam_search.full_scorers.items() + k for k, v in beam_search.full_scorers.items() if not isinstance(v, BatchScorerInterface) ] if len(non_batch) == 0: beam_search.__class__ = BatchBeamSearch logger.info("BatchBeamSearch implementation is selected.") else: - logger.warning( - f"As non-batch scorers {non_batch} are found, " - f"fall back to non-batch implementation." 
- )
+ logger.warning(f"As non-batch scorers {non_batch} are found, "
+ f"fall back to non-batch implementation.")

 if args.ngpu > 1:
 raise NotImplementedError("only single GPU decoding is supported")
@@ -157,7 +158,7 @@ def recog_v2(args):
 with jsonlines.open(args.recog_json, "r") as reader:
 for item in reader:
 js.append(item)
- # josnlines to dict, key by 'utt'
+ # jsonlines to dict, key by 'utt', value by jsonline
 js = {item['utt']: item for item in js}

 new_js = {}
@@ -169,25 +170,26 @@
 feat = load_inputs_and_targets(batch)[0][0]
 logger.info(f'feat: {feat.shape}')
 enc = model.encode(paddle.to_tensor(feat).to(dtype))
- logger.info(f'eouts: {enc.shape}')
- nbest_hyps = beam_search(
- x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio
- )
+ logger.info(f'eout: {enc.shape}')
+ nbest_hyps = beam_search(x=enc,
+ maxlenratio=args.maxlenratio,
+ minlenratio=args.minlenratio)
 nbest_hyps = [
- h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)]
+ h.asdict()
+ for h in nbest_hyps[:min(len(nbest_hyps), args.nbest)]
 ]
- new_js[name] = add_results_to_json(
- js[name], nbest_hyps, char_list
- )
+ new_js[name] = add_results_to_json(js[name], nbest_hyps,
+ char_list)

- item = new_js[name]['output'][0] # 1-best
- utt = name
+ item = new_js[name]['output'][0] # 1-best
 ref = item['text']
- rec_text = item['rec_text'].replace('▁', ' ').replace('<eos>', '').strip()
- rec_tokenid = map(int, item['rec_tokenid'].split())
+ rec_text = item['rec_text'].replace('▁',
+ ' ').replace('<eos>',
+ '').strip()
+ rec_tokenid = list(map(int, item['rec_tokenid'].split()))
 f.write({
- "utt": utt,
- "refs": [ref],
- "hyps": [rec_text],
- "hyps_tokenid": [rec_tokenid],
- })
\ No newline at end of file
+ "utt": name,
+ "refs": [ref],
+ "hyps": [rec_text],
+ "hyps_tokenid": [rec_tokenid],
+ })
diff --git a/deepspeech/decoders/recog_bin.py b/deepspeech/decoders/recog_bin.py
new file mode 100644
index 00000000..567dfecd
--- /dev/null
+++ b/deepspeech/decoders/recog_bin.py
@@ -0,0 +1,376 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""End-to-end speech recognition model decoding script."""
+import logging
+import os
+import random
+import sys
+from distutils.util import strtobool
+
+import configargparse
+import numpy as np
+
+from .recog import recog_v2
+
+
+def get_parser():
+ """Get default arguments."""
+ parser = configargparse.ArgumentParser(
+ description="Transcribe text from speech using "
+ "a speech recognition model on one CPU or GPU",
+ config_file_parser_class=configargparse.YAMLConfigFileParser,
+ formatter_class=configargparse.ArgumentDefaultsHelpFormatter, )
+ parser.add(
+ '--model-name',
+ type=str,
+ default='u2_kaldi',
+ help='model name, e.g.: deepspeech2, u2, u2_kaldi, u2_st')
+ # general configuration
+ parser.add("--config", is_config_file=True, help="Config file path")
+ parser.add(
+ "--config2",
+ is_config_file=True,
+ help="Second config file path that overwrites the settings in `--config`",
+ )
+ parser.add(
+ "--config3",
+ is_config_file=True,
+ help="Third config file path that overwrites the settings "
+ "in `--config` and `--config2`", )
+
+ parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
+ parser.add_argument(
+ "--dtype",
+ choices=("float16", "float32", "float64"),
+ default="float32",
+ help="Float precision (only available in --api v2)", )
+ parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
+ parser.add_argument("--seed", type=int, default=1, help="Random seed")
+ parser.add_argument(
+ "--verbose", "-V", type=int, default=2, help="Verbose option")
+ parser.add_argument(
+ "--batchsize",
+ type=int,
+ default=1,
+ help="Batch size for beam search (0: means no batch processing)", )
+ parser.add_argument(
+ "--preprocess-conf",
+ type=str,
+ default=None,
+ help="The configuration file for the pre-processing", )
+ parser.add_argument(
+ "--api",
+ default="v2",
+ choices=["v2"],
+ help="Beam search APIs "
+ "v2: Experimental API. It supports any models that implement ScorerInterface.",
+ )
+ # task related
+ parser.add_argument(
+ "--recog-json", type=str, help="Filename of recognition data (json)")
+ parser.add_argument(
+ "--result-label",
+ type=str,
+ required=True,
+ help="Filename of result label data (json)", )
+ # model (parameter) related
+ parser.add_argument(
+ "--model",
+ type=str,
+ required=True,
+ help="Model file parameters to read")
+ parser.add_argument(
+ "--model-conf", type=str, default=None, help="Model config file")
+ parser.add_argument(
+ "--num-spkrs",
+ type=int,
+ default=1,
+ choices=[1, 2],
+ help="Number of speakers in the speech", )
+ parser.add_argument(
+ "--num-encs",
+ default=1,
+ type=int,
+ help="Number of encoders in the model.")
+ # search related
+ parser.add_argument(
+ "--nbest", type=int, default=1, help="Output N-best hypotheses")
+ parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
+ parser.add_argument(
+ "--penalty", type=float, default=0.0, help="Insertion penalty")
+ parser.add_argument(
+ "--maxlenratio",
+ type=float,
+ default=0.0,
+ help="""Input length ratio to obtain max output length.
+ If maxlenratio=0.0 (default), it uses an end-detect function
+ to automatically find maximum hypothesis lengths.
+ If maxlenratio<0.0, its absolute value is interpreted
+ as a constant max output length""", )
+ parser.add_argument(
+ "--minlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain min output length", )
+ parser.add_argument(
+ "--ctc-weight",
+ type=float,
+ default=0.0,
+ help="CTC weight in joint decoding")
+ parser.add_argument(
+ "--weights-ctc-dec",
+ type=float,
+ action="append",
+ help="ctc weight assigned to each encoder during decoding."
+ "[in multi-encoder mode only]", )
+ parser.add_argument(
+ "--ctc-window-margin",
+ type=int,
+ default=0,
+ help="""Use CTC window with margin parameter to accelerate
+ CTC/attention decoding especially on GPU. Smaller margin
+ makes decoding faster, but may increase search errors.
+ If margin=0 (default), this function is disabled""", )
+ # transducer related
+ parser.add_argument(
+ "--search-type",
+ type=str,
+ default="default",
+ choices=["default", "nsc", "tsd", "alsd", "maes"],
+ help="""Type of beam search implementation to use during inference.
+ Can be either: default beam search ("default"),
+ N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"),
+ Alignment-Length Synchronous Decoding ("alsd") or
+ modified Adaptive Expansion Search ("maes").""", )
+ parser.add_argument(
+ "--nstep",
+ type=int,
+ default=1,
+ help="""Number of expansion steps allowed in NSC beam search or mAES
+ (nstep > 0 for NSC and nstep > 1 for mAES).""", )
+ parser.add_argument(
+ "--prefix-alpha",
+ type=int,
+ default=2,
+ help="Length prefix difference allowed in NSC beam search or mAES.", )
+ parser.add_argument(
+ "--max-sym-exp",
+ type=int,
+ default=2,
+ help="Number of symbol expansions allowed in TSD.", )
+ parser.add_argument(
+ "--u-max",
+ type=int,
+ default=400,
+ help="Length prefix difference allowed in ALSD.", )
+ parser.add_argument(
+ "--expansion-gamma",
+ type=float,
+ default=2.3,
+ help="Allowed logp difference for prune-by-value method in mAES.", )
+ parser.add_argument(
+ "--expansion-beta",
+ type=int,
+ default=2,
+ help="""Number of additional candidates for expanded hypotheses
+ selection in mAES.""", )
+ parser.add_argument(
+ "--score-norm",
+ type=strtobool,
+ nargs="?",
+ default=True,
+ help="Normalize final hypotheses' score by length", )
+ parser.add_argument(
+ "--softmax-temperature",
+ type=float,
+ default=1.0,
+ help="Penalization term for softmax function.", )
+ # rnnlm related
+ parser.add_argument(
+ "--rnnlm", type=str, default=None, help="RNNLM model file to read")
+ parser.add_argument(
+ "--rnnlm-conf",
+ type=str,
+ default=None,
+ help="RNNLM model config file to read")
+ parser.add_argument(
+ "--word-rnnlm",
+ type=str,
+ default=None,
+ help="Word RNNLM model file to read")
+ parser.add_argument(
+ "--word-rnnlm-conf",
+ type=str,
+ default=None,
+ help="Word RNNLM model config file to read", )
+ parser.add_argument(
+ "--word-dict", type=str, default=None, help="Word list to read")
+ parser.add_argument(
+ "--lm-weight", type=float, default=0.1, help="RNNLM weight")
+ # ngram related
+ parser.add_argument(
+ "--ngram-model",
+ type=str,
+ default=None,
+ help="ngram model file to read")
+ parser.add_argument(
+ "--ngram-weight", type=float, default=0.1, help="ngram weight")
+ parser.add_argument(
+ "--ngram-scorer",
+ type=str,
+ default="part",
+ choices=("full", "part"),
+ help="""if the ngram is set as a part scorer, similar to the CTC scorer,
+ ngram scorer only scores topK hypotheses.
+ if the ngram is set as full scorer, ngram scorer scores all hypotheses;
+ the decoding speed of the part scorer is much faster than the full one""",
+ )
+ # streaming related
+ parser.add_argument(
+ "--streaming-mode",
+ type=str,
+ default=None,
+ choices=["window", "segment"],
+ help="""Use streaming recognizer for inference.
+ `--batchsize` must be set to 0 to enable this mode""", )
+ parser.add_argument(
+ "--streaming-window", type=int, default=10, help="Window size")
+ parser.add_argument(
+ "--streaming-min-blank-dur",
+ type=int,
+ default=10,
+ help="Minimum blank duration threshold", )
+ parser.add_argument(
+ "--streaming-onset-margin", type=int, default=1, help="Onset margin")
+ parser.add_argument(
+ "--streaming-offset-margin", type=int, default=1, help="Offset margin")
+ # non-autoregressive related
+ # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
+ parser.add_argument(
+ "--maskctc-n-iterations",
+ type=int,
+ default=10,
+ help="Number of decoding iterations."
+ "For Mask CTC, set 0 to predict 1 mask/iter.", )
+ parser.add_argument(
+ "--maskctc-probability-threshold",
+ type=float,
+ default=0.999,
+ help="Threshold probability for CTC output", )
+ # quantize model related
+ parser.add_argument(
+ "--quantize-config",
+ nargs="*",
+ help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]",
+ )
+ parser.add_argument(
+ "--quantize-dtype",
+ type=str,
+ default="qint8",
+ help="Dtype for dynamic quantization")
+ parser.add_argument(
+ "--quantize-asr-model",
+ type=bool,
+ default=False,
+ help="Quantize asr model", )
+ parser.add_argument(
+ "--quantize-lm-model",
+ type=bool,
+ default=False,
+ help="Quantize lm model", )
+ return parser
+
+
+def main(args):
+ """Run the main decoding function."""
+ parser = get_parser()
+ parser.add_argument(
+ "--output", metavar="CKPT_DIR", help="path to save checkpoint.")
+ parser.add_argument(
+ "--checkpoint_path", type=str, help="path to load checkpoint")
+ parser.add_argument("--dict-path", type=str, help="path to the vocabulary file")
+ args = parser.parse_args(args)
+
+ if args.ngpu == 0 and args.dtype == "float16":
+ raise ValueError(
+ f"--dtype {args.dtype} does not support the CPU backend.")
+
+ # logging info
+ if args.verbose == 1:
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ elif args.verbose == 2:
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ else:
+ logging.basicConfig(
+ level=logging.WARN,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.warning("Skip DEBUG/INFO messages")
+ logging.info(args)
+
+ # check CUDA_VISIBLE_DEVICES
+ if args.ngpu > 0:
+ cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
+ if cvd is None:
+ logging.warning("CUDA_VISIBLE_DEVICES is not set.")
+ elif args.ngpu != len(cvd.split(",")):
+ logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
+ sys.exit(1)
+
+ # TODO(mn5k): support of multiple GPUs
+ if args.ngpu > 1:
+ logging.error("The program only supports ngpu=1.")
+ sys.exit(1)
+
+ # display PYTHONPATH
+ logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
+
+ # seed setting
+ random.seed(args.seed)
+ np.random.seed(args.seed)
+ logging.info("set random seed = %d" % args.seed)
+
+ # validate rnn options
+ if args.rnnlm is not None and args.word_rnnlm is not None:
+ logging.error(
+ "It seems that both --rnnlm and --word-rnnlm are specified. "
+ "Please use either option.")
+ sys.exit(1)
+
+ # recog
+ if args.num_spkrs == 1:
+ if args.num_encs == 1:
+ # Experimental API that supports custom LMs
+ if args.api == "v2":
+ from deepspeech.decoders.recog import recog_v2
+ recog_v2(args)
+ else:
+ raise ValueError("Only support --api v2")
+ else:
+ if args.api == "v2":
+ raise NotImplementedError(
+ f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
+ )
+ elif args.num_spkrs == 2:
+ raise ValueError("asr_mix not supported.")
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/deepspeech/decoders/scorers/ngram.py b/deepspeech/decoders/scorers/ngram.py
index 050a8c81..a34d8248 100644
--- a/deepspeech/decoders/scorers/ngram.py
+++ b/deepspeech/decoders/scorers/ngram.py
@@ -85,8 +85,9 @@ class NgramFullScorer(Ngrambase, BatchScorerInterface):
 and next state list for ys.

 """
- return self.score_partial_(
- y, paddle.to_tensor(range(self.charlen)), state, x)
+ return self.score_partial_(y,
+ paddle.to_tensor(range(self.charlen)), state,
+ x)


 class NgramPartScorer(Ngrambase, PartialScorerInterface):
diff --git a/deepspeech/decoders/utils.py b/deepspeech/decoders/utils.py
index f59b55d9..3ed9c5da 100644
--- a/deepspeech/decoders/utils.py
+++ b/deepspeech/decoders/utils.py
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import numpy as np
+
 from deepspeech.utils.log import Log

 logger = Log(__name__).getlog()
@@ -98,7 +98,8 @@ def add_results_to_json(js, nbest_hyps, char_list):

 for n, hyp in enumerate(nbest_hyps, 1):
 # parse hypothesis
- rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
+ rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp,
+ char_list)

 # copy ground-truth
 if len(js["output"]) > 0:
@@ -125,4 +126,4 @@
 logger.info("groundtruth: %s" % out_dic["text"])
 logger.info("prediction : %s" % out_dic["rec_text"])

- return new_js \ No newline at end of file
+ return new_js
diff --git a/deepspeech/exps/__init__.py b/deepspeech/exps/__init__.py
index 185a92b8..29953014 100644
--- a/deepspeech/exps/__init__.py
+++ b/deepspeech/exps/__init__.py
@@ -11,3 +11,52 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from deepspeech.training.trainer import Trainer
+from deepspeech.utils.dynamic_import import dynamic_import
+
+model_trainer_alias = {
+ "ds2": "deepspeech.exps.deepspeech2.model:DeepSpeech2Trainer",
+ "u2": "deepspeech.exps.u2.model:U2Trainer",
+ "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer",
+ "u2_st": "deepspeech.exps.u2_st.model:U2STTrainer",
+}
+
+
+def dynamic_import_trainer(module):
+ """Import Trainer dynamically.
+
+ Args:
+ module (str): trainer name. e.g., ds2, u2, u2_kaldi
+
+ Returns:
+ type: Trainer class
+
+ """
+ model_class = dynamic_import(module, model_trainer_alias)
+ assert issubclass(model_class,
+ Trainer), f"{module} does not implement Trainer"
+ return model_class
+
+
+model_tester_alias = {
+ "ds2": "deepspeech.exps.deepspeech2.model:DeepSpeech2Tester",
+ "u2": "deepspeech.exps.u2.model:U2Tester",
+ "u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester",
+ "u2_st": "deepspeech.exps.u2_st.model:U2STTester",
+}
+
+
+def dynamic_import_tester(module):
+ """Import Tester dynamically.
+
+ Args:
+ module (str): tester name.
e.g., ds2, u2, u2_kaldi + + Returns: + type: Tester class + + """ + model_class = dynamic_import(module, model_tester_alias) + assert issubclass(model_class, + Trainer), f"{module} does not implement Tester" + return model_class diff --git a/deepspeech/exps/u2_kaldi/bin/recog.py b/deepspeech/exps/u2_kaldi/bin/recog.py index 60f4e58a..e94a1ab1 100644 --- a/deepspeech/exps/u2_kaldi/bin/recog.py +++ b/deepspeech/exps/u2_kaldi/bin/recog.py @@ -1,379 +1,19 @@ - -"""End-to-end speech recognition model decoding script.""" - -import configargparse -import logging -import os -import random +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import sys -import numpy as np - -from distutils.util import strtobool -from deepspeech.training.cli import default_argument_parser - -# NOTE: you need this func to generate our sphinx doc - -def get_parser(): - """Get default arguments.""" - parser = configargparse.ArgumentParser( - description="Transcribe text from speech using " - "a speech recognition model on one CPU or GPU", - config_file_parser_class=configargparse.YAMLConfigFileParser, - formatter_class=configargparse.ArgumentDefaultsHelpFormatter, - ) - parser.add( - '--model-name', - type=str, - default='u2_kaldi', - help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st') - # general configuration - parser.add("--config", is_config_file=True, help="Config file path") - parser.add( - "--config2", - is_config_file=True, - help="Second config file path that overwrites the settings in `--config`", - ) - parser.add( - "--config3", - is_config_file=True, - help="Third config file path that overwrites the settings " - "in `--config` and `--config2`", - ) - - parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs") - parser.add_argument( - "--dtype", - choices=("float16", "float32", "float64"), - default="float32", - help="Float precision (only available in --api v2)", - ) - parser.add_argument("--debugmode", type=int, default=1, help="Debugmode") - parser.add_argument("--seed", type=int, default=1, help="Random seed") - parser.add_argument("--verbose", "-V", type=int, default=2, help="Verbose option") - parser.add_argument( - "--batchsize", - type=int, - default=1, - help="Batch size for beam search (0: means no batch processing)", - ) - parser.add_argument( - "--preprocess-conf", - type=str, - default=None, - help="The configuration file for the pre-processing", - ) - parser.add_argument( - "--api", - default="v2", - choices=["v2"], - help="Beam search APIs " - "v2: Experimental API. 
It supports any models that implements ScorerInterface.", - ) - # task related - parser.add_argument( - "--recog-json", type=str, help="Filename of recognition data (json)" - ) - parser.add_argument( - "--result-label", - type=str, - required=True, - help="Filename of result label data (json)", - ) - # model (parameter) related - parser.add_argument( - "--model", type=str, required=True, help="Model file parameters to read" - ) - parser.add_argument( - "--model-conf", type=str, default=None, help="Model config file" - ) - parser.add_argument( - "--num-spkrs", - type=int, - default=1, - choices=[1, 2], - help="Number of speakers in the speech", - ) - parser.add_argument( - "--num-encs", default=1, type=int, help="Number of encoders in the model." - ) - # search related - parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") - parser.add_argument("--beam-size", type=int, default=1, help="Beam size") - parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty") - parser.add_argument( - "--maxlenratio", - type=float, - default=0.0, - help="""Input length ratio to obtain max output length. - If maxlenratio=0.0 (default), it uses a end-detect function - to automatically find maximum hypothesis lengths. - If maxlenratio<0.0, its absolute value is interpreted - as a constant max output length""", - ) - parser.add_argument( - "--minlenratio", - type=float, - default=0.0, - help="Input length ratio to obtain min output length", - ) - parser.add_argument( - "--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding" - ) - parser.add_argument( - "--weights-ctc-dec", - type=float, - action="append", - help="ctc weight assigned to each encoder during decoding." - "[in multi-encoder mode only]", - ) - parser.add_argument( - "--ctc-window-margin", - type=int, - default=0, - help="""Use CTC window with margin parameter to accelerate - CTC/attention decoding especially on GPU. Smaller magin - makes decoding faster, but may increase search errors. - If margin=0 (default), this function is disabled""", - ) - # transducer related - parser.add_argument( - "--search-type", - type=str, - default="default", - choices=["default", "nsc", "tsd", "alsd", "maes"], - help="""Type of beam search implementation to use during inference. 
- Can be either: default beam search ("default"), - N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"), - Alignment-Length Synchronous Decoding ("alsd") or - modified Adaptive Expansion Search ("maes").""", - ) - parser.add_argument( - "--nstep", - type=int, - default=1, - help="""Number of expansion steps allowed in NSC beam search or mAES - (nstep > 0 for NSC and nstep > 1 for mAES).""", - ) - parser.add_argument( - "--prefix-alpha", - type=int, - default=2, - help="Length prefix difference allowed in NSC beam search or mAES.", - ) - parser.add_argument( - "--max-sym-exp", - type=int, - default=2, - help="Number of symbol expansions allowed in TSD.", - ) - parser.add_argument( - "--u-max", - type=int, - default=400, - help="Length prefix difference allowed in ALSD.", - ) - parser.add_argument( - "--expansion-gamma", - type=float, - default=2.3, - help="Allowed logp difference for prune-by-value method in mAES.", - ) - parser.add_argument( - "--expansion-beta", - type=int, - default=2, - help="""Number of additional candidates for expanded hypotheses - selection in mAES.""", - ) - parser.add_argument( - "--score-norm", - type=strtobool, - nargs="?", - default=True, - help="Normalize final hypotheses' score by length", - ) - parser.add_argument( - "--softmax-temperature", - type=float, - default=1.0, - help="Penalization term for softmax function.", - ) - # rnnlm related - parser.add_argument( - "--rnnlm", type=str, default=None, help="RNNLM model file to read" - ) - parser.add_argument( - "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read" - ) - parser.add_argument( - "--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read" - ) - parser.add_argument( - "--word-rnnlm-conf", - type=str, - default=None, - help="Word RNNLM model config file to read", - ) - parser.add_argument("--word-dict", type=str, default=None, help="Word list to read") - parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight") - # ngram related - parser.add_argument( - "--ngram-model", type=str, default=None, help="ngram model file to read" - ) - parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight") - parser.add_argument( - "--ngram-scorer", - type=str, - default="part", - choices=("full", "part"), - help="""if the ngram is set as a part scorer, similar with CTC scorer, - ngram scorer only scores topK hypethesis. - if the ngram is set as full scorer, ngram scorer scores all hypthesis - the decoding speed of part scorer is musch faster than full one""", - ) - # streaming related - parser.add_argument( - "--streaming-mode", - type=str, - default=None, - choices=["window", "segment"], - help="""Use streaming recognizer for inference. - `--batchsize` must be set to 0 to enable this mode""", - ) - parser.add_argument("--streaming-window", type=int, default=10, help="Window size") - parser.add_argument( - "--streaming-min-blank-dur", - type=int, - default=10, - help="Minimum blank duration threshold", - ) - parser.add_argument( - "--streaming-onset-margin", type=int, default=1, help="Onset margin" - ) - parser.add_argument( - "--streaming-offset-margin", type=int, default=1, help="Offset margin" - ) - # non-autoregressive related - # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail. - parser.add_argument( - "--maskctc-n-iterations", - type=int, - default=10, - help="Number of decoding iterations." 
- "For Mask CTC, set 0 to predict 1 mask/iter.", - ) - parser.add_argument( - "--maskctc-probability-threshold", - type=float, - default=0.999, - help="Threshold probability for CTC output", - ) - # quantize model related - parser.add_argument( - "--quantize-config", - nargs="*", - help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]", - ) - parser.add_argument( - "--quantize-dtype", type=str, default="qint8", help="Dtype dynamic quantize" - ) - parser.add_argument( - "--quantize-asr-model", - type=bool, - default=False, - help="Quantize asr model", - ) - parser.add_argument( - "--quantize-lm-model", - type=bool, - default=False, - help="Quantize lm model", - ) - return parser - - -def main(args): - """Run the main decoding function.""" - parser = get_parser() - parser.add_argument( - "--output", metavar="CKPT_DIR", help="path to save checkpoint.") - parser.add_argument( - "--checkpoint_path", type=str, help="path to load checkpoint") - parser.add_argument( - "--dict-path", type=str, help="path to load checkpoint") - # parser = default_argument_parser(parser) - args = parser.parse_args(args) - - if args.ngpu == 0 and args.dtype == "float16": - raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.") - - # logging info - if args.verbose == 1: - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - elif args.verbose == 2: - logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - else: - logging.basicConfig( - level=logging.WARN, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) - logging.warning("Skip DEBUG/INFO messages") - logging.info(args) - - # check CUDA_VISIBLE_DEVICES - if args.ngpu > 0: - cvd = os.environ.get("CUDA_VISIBLE_DEVICES") - if cvd is None: - logging.warning("CUDA_VISIBLE_DEVICES is not set.") - elif args.ngpu != len(cvd.split(",")): - logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") - sys.exit(1) - - # TODO(mn5k): support of multiple GPUs - if args.ngpu > 1: - logging.error("The program only supports ngpu=1.") - sys.exit(1) - - # display PYTHONPATH - logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) - - # seed setting - random.seed(args.seed) - np.random.seed(args.seed) - logging.info("set random seed = %d" % args.seed) - - # validate rnn options - if args.rnnlm is not None and args.word_rnnlm is not None: - logging.error( - "It seems that both --rnnlm and --word-rnnlm are specified. " - "Please use either option." 
- ) - sys.exit(1) - - # recog - if args.num_spkrs == 1: - if args.num_encs == 1: - # Experimental API that supports custom LMs - if args.api == "v2": - from deepspeech.decoders.recog import recog_v2 - recog_v2(args) - else: - raise ValueError("Only support --api v2") - else: - if args.api == "v2": - raise NotImplementedError( - f"--num-encs {args.num_encs} > 1 is not supported in --api v2" - ) - elif args.num_spkrs == 2: - raise ValueError("asr_mix not supported.") - +from deepspeech.decoders.recog_bin import main if __name__ == "__main__": main(sys.argv[1:]) diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 3aadca85..f8624326 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -434,8 +434,9 @@ class U2Tester(U2Trainer): simulate_streaming=cfg.simulate_streaming) decode_time = time.time() - start_time - for i, (utt, target, result, rec_tids) in enumerate(zip( - utts, target_transcripts, result_transcripts, result_tokenids)): + for i, (utt, target, result, rec_tids) in enumerate( + zip(utts, target_transcripts, result_transcripts, + result_tokenids)): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref diff --git a/deepspeech/frontend/featurizer/text_featurizer.py b/deepspeech/frontend/featurizer/text_featurizer.py index 18ff411b..a6834ebc 100644 --- a/deepspeech/frontend/featurizer/text_featurizer.py +++ b/deepspeech/frontend/featurizer/text_featurizer.py @@ -140,7 +140,7 @@ class TextFeaturizer(): Returns: str: text string. """ - tokens = [t.replace(SPACE, " ") for t in tokens ] + tokens = [t.replace(SPACE, " ") for t in tokens] return "".join(tokens) def word_tokenize(self, text): diff --git a/deepspeech/models/asr_interface.py b/deepspeech/models/asr_interface.py index eb820fc0..7dac81b4 100644 --- a/deepspeech/models/asr_interface.py +++ b/deepspeech/models/asr_interface.py @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ASR Interface module.""" import argparse @@ -72,7 +85,8 @@ class ASRInterface: :return: attention weights (B, Lmax, Tmax) :rtype: float ndarray """ - raise NotImplementedError("calculate_all_attentions method is not implemented") + raise NotImplementedError( + "calculate_all_attentions method is not implemented") def calculate_all_ctc_probs(self, xs, ilens, ys): """Calculate CTC probability. 
@@ -83,7 +97,8 @@ class ASRInterface: :return: CTC probabilities (B, Tmax, vocab) :rtype: float ndarray """ - raise NotImplementedError("calculate_all_ctc_probs method is not implemented") + raise NotImplementedError( + "calculate_all_ctc_probs method is not implemented") @property def attention_plot_class(self): @@ -102,8 +117,7 @@ class ASRInterface: def get_total_subsampling_factor(self): """Get total subsampling factor.""" raise NotImplementedError( - "get_total_subsampling_factor method is not implemented" - ) + "get_total_subsampling_factor method is not implemented") def encode(self, feat): """Encode feature in `beam_search` (optional). @@ -126,23 +140,22 @@ class ASRInterface: predefined_asr = { - "transformer": "deepspeech.models.u2:E2E", - "conformer": "deepspeech.models.u2:E2E", + "transformer": "deepspeech.models.u2:U2Model", + "conformer": "deepspeech.models.u2:U2Model", } -def dynamic_import_asr(module, name): + +def dynamic_import_asr(module): """Import ASR models dynamically. Args: - module (str): module_name:class_name or alias in `predefined_asr` - name (str): asr name. e.g., transformer, conformer + module (str): asr name. e.g., transformer, conformer Returns: type: ASR class """ - model_class = dynamic_import(module, predefined_asr.get(name, "")) - assert issubclass( - model_class, ASRInterface - ), f"{module} does not implement ASRInterface" + model_class = dynamic_import(module, predefined_asr) + assert issubclass(model_class, + ASRInterface), f"{module} does not implement ASRInterface" return model_class diff --git a/deepspeech/models/u2/u2.py b/deepspeech/models/u2/u2.py index fa517906..6cd3b775 100644 --- a/deepspeech/models/u2/u2.py +++ b/deepspeech/models/u2/u2.py @@ -28,8 +28,10 @@ from paddle import jit from paddle import nn from yacs.config import CfgNode +from deepspeech.decoders.scorers.ctc import CTCPrefixScorer from deepspeech.frontend.utility import IGNORE_ID from deepspeech.frontend.utility import load_cmvn +from deepspeech.models.asr_interface import ASRInterface from deepspeech.modules.cmvn import GlobalCMVN from deepspeech.modules.ctc import CTCDecoder from deepspeech.modules.decoder import TransformerDecoder @@ -49,8 +51,6 @@ from deepspeech.utils.tensor_utils import pad_sequence from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.utility import log_add from deepspeech.utils.utility import UpdateConfig -from deepspeech.models.asr_interface import ASRInterface -from deepspeech.decoders.scorers.ctc import CTCPrefixScorer __all__ = ["U2Model", "U2InferModel"] @@ -816,10 +816,10 @@ class U2BaseModel(ASRInterface, nn.Layer): class U2DecodeModel(U2BaseModel): - def scorers(self): """Scorers.""" - return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) + return dict( + decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) def encode(self, x): """Encode acoustic features. diff --git a/deepspeech/modules/decoder.py b/deepspeech/modules/decoder.py index ee3572ea..735f06dc 100644 --- a/deepspeech/modules/decoder.py +++ b/deepspeech/modules/decoder.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Decoder definition.""" +from typing import Any from typing import List from typing import Optional from typing import Tuple -from typing import Any import paddle from paddle import nn from typeguard import check_argument_types +from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.decoder_layer import DecoderLayer from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.mask import make_non_pad_mask -from deepspeech.modules.mask import subsequent_mask from deepspeech.modules.mask import make_xs_mask +from deepspeech.modules.mask import subsequent_mask from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward -from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface from deepspeech.utils.log import Log logger = Log(__name__).getlog() @@ -191,8 +191,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): ys: (ylen,) x: (xlen, n_feat) """ - ys_mask = subsequent_mask(len(ys)).unsqueeze(0) # (B,L,L) - x_mask = make_xs_mask(x.unsqueeze(0)).unsqueeze(1) # (B,1,T) + ys_mask = subsequent_mask(len(ys)).unsqueeze(0) # (B,L,L) + x_mask = make_xs_mask(x.unsqueeze(0)).unsqueeze(1) # (B,1,T) if self.selfattention_layer_type != "selfattn": # TODO(karita): implement cache logging.warning( @@ -200,16 +200,14 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): ) state = None logp, state = self.forward_one_step( - x.unsqueeze(0), x_mask, - ys.unsqueeze(0), ys_mask, - cache=state - ) + x.unsqueeze(0), x_mask, ys.unsqueeze(0), ys_mask, cache=state) return logp.squeeze(0), state # batch beam search API (see BatchScorerInterface) - def batch_score( - self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor - ) -> Tuple[paddle.Tensor, List[Any]]: + def batch_score(self, + ys: paddle.Tensor, + states: List[Any], + xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]: """Score new token batch (required). Args: @@ -237,10 +235,12 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): ] # batch decoding - ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0) # (B,L,L) - xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T) - logp, states = self.forward_one_step(xs, xs_mask, ys, ys_mask, cache=batch_state) + ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0) # (B,L,L) + xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T) + logp, states = self.forward_one_step( + xs, xs_mask, ys, ys_mask, cache=batch_state) # transpose state of [layer, batch] into [batch, layer] - state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] + state_list = [[states[i][b] for i in range(n_layers)] + for b in range(n_batch)] return logp, state_list diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py index 7ae418c0..52f8e4bc 100644 --- a/deepspeech/modules/mask.py +++ b/deepspeech/modules/mask.py @@ -24,7 +24,7 @@ __all__ = [ ] -def make_xs_mask(xs:paddle.Tensor, pad_value=0.0) -> paddle.Tensor: +def make_xs_mask(xs: paddle.Tensor, pad_value=0.0) -> paddle.Tensor: """Maks mask tensor containing indices of non-padded part. Args: xs (paddle.Tensor): (B, T, D), zeros for pad. 
diff --git a/deepspeech/modules/mask.py b/deepspeech/modules/mask.py
index 7ae418c0..52f8e4bc 100644
--- a/deepspeech/modules/mask.py
+++ b/deepspeech/modules/mask.py
@@ -24,7 +24,7 @@ __all__ = [
 ]
 
 
-def make_xs_mask(xs:paddle.Tensor, pad_value=0.0) -> paddle.Tensor:
+def make_xs_mask(xs: paddle.Tensor, pad_value=0.0) -> paddle.Tensor:
     """Maks mask tensor containing indices of non-padded part.
     Args:
         xs (paddle.Tensor): (B, T, D), zeros for pad.
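For reference, a small usage sketch of `make_xs_mask` as documented above (the output shape follows the docstring; the exact dtype is an assumption, not part of this patch):

    import paddle
    from deepspeech.modules.mask import make_xs_mask

    xs = paddle.ones([1, 4, 2])  # (B, T, D)
    xs[0, 2:] = 0.0              # last two frames all-zero, i.e. padding
    mask = make_xs_mask(xs)      # expected (B, T) mask, falsy at the padded frames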
diff --git a/deepspeech/training/cli.py b/deepspeech/training/cli.py
index db7076d3..14a34cb7 100644
--- a/deepspeech/training/cli.py
+++ b/deepspeech/training/cli.py
@@ -64,7 +64,7 @@ def default_argument_parser(parser=None):
     """
     if parser is None:
         parser = argparse.ArgumentParser()
-    
+
     parser.register('action', 'extend', ExtendAction)
     parser.add_argument(
         '--conf', type=open, action=LoadFromFile, help="config file.")
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index b2ee7a1b..2c238920 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -126,7 +126,8 @@ class Trainer():
             logger.info(f"Set seed {args.seed}")
 
         # profiler and benchmark options
-        if hasattr(self.args, "benchmark_batch_size") and self.args.benchmark_batch_size:
+        if hasattr(self.args,
+                   "benchmark_batch_size") and self.args.benchmark_batch_size:
             with UpdateConfig(self.config):
                 self.config.collator.batch_size = self.args.benchmark_batch_size
                 self.config.training.log_interval = 1
@@ -335,8 +336,7 @@ class Trainer():
         """
         assert self.args.checkpoint_path
         infos = self.checkpoint.load_latest_parameters(
-            self.model,
-            checkpoint_path=self.args.checkpoint_path)
+            self.model, checkpoint_path=self.args.checkpoint_path)
         return infos
 
     def run_test(self):
diff --git a/examples/librispeech/README.md b/examples/librispeech/README.md
index 72459095..5943cf1d 100644
--- a/examples/librispeech/README.md
+++ b/examples/librispeech/README.md
@@ -1,8 +1,8 @@
 # ASR
 
-* s0 is for deepspeech2 offline 
-* s1 is for transformer/conformer/U2 
-* s2 is for transformer/conformer/U2 w/ kaldi feat, need install Kaldi 
+* s0 is for deepspeech2 offline
+* s1 is for transformer/conformer/U2
+* s2 is for transformer/conformer/U2 w/ kaldi feat, need install Kaldi
 
 ## Data
 | Data Subset | Duration in Seconds |
diff --git a/examples/librispeech/s2/README.md b/examples/librispeech/s2/README.md
index 6e41cc37..d5df37d8 100644
--- a/examples/librispeech/s2/README.md
+++ b/examples/librispeech/s2/README.md
@@ -1,14 +1,14 @@
 # LibriSpeech
 
 | Model | Params | Config | Augmentation| Loss |
-| --- | --- | --- | --- | 
+| --- | --- | --- | --- |
 | transformer | 32.52 M | conf/transformer.yaml | spec_aug | 6.3197922706604 |
 
-| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err | 
+| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
 | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
-| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 | 
-| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 | 
-| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 | 
-| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 | 
-| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 | 
+| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 |
+| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
+| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
+| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
+| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |
diff --git a/examples/librispeech/s2/conf/decode/decode.yaml b/examples/librispeech/s2/conf/decode/decode.yaml
index 867bf611..4c702db5 100644
--- a/examples/librispeech/s2/conf/decode/decode.yaml
+++ b/examples/librispeech/s2/conf/decode/decode.yaml
@@ -1,6 +1,6 @@
 batchsize: 0
 beam-size: 60
-ctc-weight: 0.4
+ctc-weight: 0.0
 lm-weight: 0.0
 maxlenratio: 0.0
 minlenratio: 0.0
diff --git a/examples/librispeech/s2/local/recog.sh b/examples/librispeech/s2/local/recog.sh
index 0393535b..df3846c0 100755
--- a/examples/librispeech/s2/local/recog.sh
+++ b/examples/librispeech/s2/local/recog.sh
@@ -5,11 +5,14 @@ set -e
 expdir=exp
 datadir=data
 nj=32
+tag=
+# decode config
 decode_config=conf/decode/decode.yaml
+
+# lm params
 lang_model=rnnlm.model.best
 lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/
-
 lmtag='nolm'
 
 recog_set="test-clean test-other dev-clean dev-other"
@@ -21,18 +24,21 @@ bpemode=unigram
 bpeprefix="data/bpe_${bpemode}_${nbpe}"
 bpemodel=${bpeprefix}.model
 
-if [ $# != 3 ];then
-    echo "usage: ${0} config_path dict_path ckpt_path_prefix"
-    exit -1
+# bin params
+config_path=conf/transformer.yaml
+dict=data/bpe_unigram_5000_units.txt
+ckpt_prefix=
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+if [ -z ${ckpt_prefix} ]; then
+    echo "usage: $0 --ckpt_prefix ckpt_prefix"
+    exit 1
 fi
 
 ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."
 
-config_path=$1
-dict=$2
-ckpt_prefix=$3
-
 ckpt_dir=$(dirname `dirname ${ckpt_prefix}`)
 echo "ckpt dir: ${ckpt_dir}"
 
@@ -61,7 +67,7 @@ for dmethd in join_ctc; do
     for rtask in ${recog_set}; do
     (
         echo "${rtask} dataset"
-        decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag}
+        decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag}_${tag}
         feat_recog_dir=${datadir}
         mkdir -p ${decode_dir}
         mkdir -p ${feat_recog_dir}
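With `parse_options.sh` sourced, the three positional arguments become named flags with defaults, so only the checkpoint prefix is mandatory. A typical invocation (paths are illustrative):

    ./local/recog.sh --ckpt_prefix exp/transformer/checkpoints/avg_30 --tag my_decode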
diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh
index 006a13c5..5f662d29 100755
--- a/examples/librispeech/s2/local/test.sh
+++ b/examples/librispeech/s2/local/test.sh
@@ -17,19 +17,20 @@ bpemode=unigram
 bpeprefix="data/bpe_${bpemode}_${nbpe}"
 bpemodel=${bpeprefix}.model
 
-if [ $# != 3 ];then
-    echo "usage: ${0} config_path dict_path ckpt_path_prefix"
-    exit -1
+config_path=conf/transformer.yaml
+dict=data/bpe_unigram_5000_units.txt
+ckpt_prefix=
+
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+if [ -z ${ckpt_prefix} ]; then
+    echo "usage: $0 --ckpt_prefix ckpt_prefix"
+    exit 1
 fi
 
 ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."
 
-config_path=$1
-dict=$2
-ckpt_prefix=$3
-
-
 ckpt_dir=$(dirname `dirname ${ckpt_prefix}`)
 echo "ckpt dir: ${ckpt_dir}"
 
diff --git a/requirements.txt b/requirements.txt
index 42fd645a..a7310a02 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,43 +1,43 @@
+ConfigArgParse
 coverage
 editdistance
+g2p_en
+g2pM
 gpustat
+h5py
+inflect
+jieba
 jsonlines
 kaldiio
+librosa
+llvmlite
 loguru
+matplotlib
+nltk
+numba
+numpy==1.20.0
+pandas
+phkit
 Pillow
+praatio~=4.1
 pre-commit
 pybind11
+pypinyin
+pyworld
 resampy==0.2.2
 sacrebleu
 scipy==1.2.1
 sentencepiece
 snakeviz
+soundfile~=0.10
 sox
 tensorboardX
 textgrid
+timer
 tqdm
 typeguard
-visualdl==2.2.0
-yacs
-numpy==1.20.0
-numba
-nltk
-inflect
-librosa
 unidecode
-llvmlite
-matplotlib
-pandas
-soundfile~=0.10
-g2p_en
-pypinyin
+visualdl==2.2.0
 webrtcvad
-g2pM
-praatio~=4.1
-h5py
-timer
-pyworld
-jieba
-phkit
+yacs
 yq
-ConfigArgParse
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 07b6aac0..bd982129 100644
--- a/setup.py
+++ b/setup.py
@@ -11,20 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
+import inspect
 import io
 import os
 import re
+import subprocess as sp
 import sys
 from pathlib import Path
-import contextlib
-import inspect
 
+from setuptools import Command
 from setuptools import find_packages
 from setuptools import setup
-from setuptools import Command
 from setuptools.command.develop import develop
 from setuptools.command.install import install
-import subprocess as sp
 
 HERE = Path(os.path.abspath(os.path.dirname(__file__)))
 
@@ -40,16 +40,18 @@ def pushd(new_dir):
 
 
 def read(*names, **kwargs):
-    with io.open(os.path.join(os.path.dirname(__file__), *names),
-                 encoding=kwargs.get("encoding", "utf8")) as fp:
+    with io.open(
+            os.path.join(os.path.dirname(__file__), *names),
+            encoding=kwargs.get("encoding", "utf8")) as fp:
         return fp.read()
 
 
 def check_call(cmd: str, shell=False, executable=None):
     try:
-        sp.check_call(cmd.split(),
-                      shell=shell,
-                      executable="/bin/bash" if shell else executable)
+        sp.check_call(
+            cmd.split(),
+            shell=shell,
+            executable="/bin/bash" if shell else executable)
     except sp.CalledProcessError as e:
         print(
             f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:",
@@ -189,7 +191,6 @@ setup_info = dict(
         'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
-    ],
-)
+    ], )
 
 setup(**setup_info)
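The reflowed `check_call` helper above keeps its behavior: it tokenizes the command string and, on failure, prints the failing command with file/line context. A usage sketch (the command is illustrative):

    check_call("ls -l")  # tokenized to ["ls", "-l"]; failure path prints CMD and error context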