recog into decoders, format code

pull/928/head
Hui Zhang 3 years ago
parent ee6446a3aa
commit dfd80b3aa2

@ -233,7 +233,8 @@ def is_broadcastable(shp1, shp2):
def masked_fill(xs: paddle.Tensor, def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
value: Union[float, int]): value: Union[float, int]):
assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape, mask.shape) assert is_broadcastable(xs.shape, mask.shape) is True, (xs.shape,
mask.shape)
bshape = paddle.broadcast_shape(xs.shape, mask.shape) bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape) mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value trues = paddle.ones_like(xs) * value
@ -312,18 +313,18 @@ if not hasattr(paddle.Tensor, 'type_as'):
def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor: def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
assert len(args) == 1 assert len(args) == 1
if isinstance(args[0], str): # dtype if isinstance(args[0], str): # dtype
return x.astype(args[0]) return x.astype(args[0])
elif isinstance(args[0], paddle.Tensor): # Tensor elif isinstance(args[0], paddle.Tensor): # Tensor
return x.astype(args[0].dtype) return x.astype(args[0].dtype)
else: # Device else: # Device
return x return x
if not hasattr(paddle.Tensor, 'to'): if not hasattr(paddle.Tensor, 'to'):
logger.debug("register user to to paddle.Tensor, remove this when fixed!") logger.debug("register user to to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'to', to) setattr(paddle.Tensor, 'to', to)
def func_float(x: paddle.Tensor) -> paddle.Tensor: def func_float(x: paddle.Tensor) -> paddle.Tensor:
@ -355,7 +356,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
setattr(paddle.Tensor, 'tolist', tolist) setattr(paddle.Tensor, 'tolist', tolist)
########### hcak paddle.nn.functional ############# ########### hcak paddle.nn.functional #############
# hack loss # hack loss
def ctc_loss(logits, def ctc_loss(logits,
@ -384,7 +384,6 @@ logger.debug(
) )
F.ctc_loss = ctc_loss F.ctc_loss = ctc_loss
########### hcak paddle.nn ############# ########### hcak paddle.nn #############
from paddle.nn import Layer from paddle.nn import Layer
from typing import Optional from typing import Optional
@ -394,6 +393,7 @@ from typing import Tuple
from typing import Iterator from typing import Iterator
from collections import OrderedDict, abc as container_abcs from collections import OrderedDict, abc as container_abcs
class LayerDict(paddle.nn.Layer): class LayerDict(paddle.nn.Layer):
r"""Holds submodules in a dictionary. r"""Holds submodules in a dictionary.
@ -438,7 +438,7 @@ class LayerDict(paddle.nn.Layer):
return x return x
""" """
def __init__(self, modules: Optional[Mapping[str, Layer]] = None) -> None: def __init__(self, modules: Optional[Mapping[str, Layer]]=None) -> None:
super(LayerDict, self).__init__() super(LayerDict, self).__init__()
if modules is not None: if modules is not None:
self.update(modules) self.update(modules)
@ -505,10 +505,11 @@ class LayerDict(paddle.nn.Layer):
""" """
if not isinstance(modules, container_abcs.Iterable): if not isinstance(modules, container_abcs.Iterable):
raise TypeError("LayerDict.update should be called with an " raise TypeError("LayerDict.update should be called with an "
"iterable of key/value pairs, but got " + "iterable of key/value pairs, but got " + type(
type(modules).__name__) modules).__name__)
if isinstance(modules, (OrderedDict, LayerDict, container_abcs.Mapping)): if isinstance(modules,
(OrderedDict, LayerDict, container_abcs.Mapping)):
for key, module in modules.items(): for key, module in modules.items():
self[key] = module self[key] = module
else: else:
@ -520,14 +521,15 @@ class LayerDict(paddle.nn.Layer):
type(m).__name__) type(m).__name__)
if not len(m) == 2: if not len(m) == 2:
raise ValueError("LayerDict update sequence element " raise ValueError("LayerDict update sequence element "
"#" + str(j) + " has length " + str(len(m)) + "#" + str(j) + " has length " + str(
"; 2 is required") len(m)) + "; 2 is required")
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
# that's too cumbersome to type correctly with overloads, so we add an ignore here # that's too cumbersome to type correctly with overloads, so we add an ignore here
self[m[0]] = m[1] # type: ignore[assignment] self[m[0]] = m[1] # type: ignore[assignment]
# remove forward alltogether to fallback on Module's _forward_unimplemented # remove forward alltogether to fallback on Module's _forward_unimplemented
if not hasattr(paddle.nn, 'LayerDict'): if not hasattr(paddle.nn, 'LayerDict'):
logger.debug( logger.debug(
"register user LayerDict to paddle.nn, remove this when fixed!") "register user LayerDict to paddle.nn, remove this when fixed!")

@ -0,0 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .batch_beam_search import BatchBeamSearch
from .beam_search import beam_search
from .beam_search import BeamSearch
from .beam_search import Hypothesis

@ -0,0 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class BatchBeamSearch():
pass

@ -1,5 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Beam search module.""" """Beam search module."""
from itertools import chain from itertools import chain
from typing import Any from typing import Any
from typing import Dict from typing import Dict
@ -10,18 +22,18 @@ from typing import Union
import paddle import paddle
from .utils import end_detect from ..scorers.scorer_interface import PartialScorerInterface
from .scorers.scorer_interface import PartialScorerInterface from ..scorers.scorer_interface import ScorerInterface
from .scorers.scorer_interface import ScorerInterface from ..utils import end_detect
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
class Hypothesis(NamedTuple): class Hypothesis(NamedTuple):
"""Hypothesis data type.""" """Hypothesis data type."""
yseq: paddle.Tensor # (T,) yseq: paddle.Tensor # (T,)
score: Union[float, paddle.Tensor] = 0 score: Union[float, paddle.Tensor] = 0
scores: Dict[str, Union[float, paddle.Tensor]] = dict() scores: Dict[str, Union[float, paddle.Tensor]] = dict()
states: Dict[str, Any] = dict() states: Dict[str, Any] = dict()
@ -31,25 +43,24 @@ class Hypothesis(NamedTuple):
return self._replace( return self._replace(
yseq=self.yseq.tolist(), yseq=self.yseq.tolist(),
score=float(self.score), score=float(self.score),
scores={k: float(v) for k, v in self.scores.items()}, scores={k: float(v)
)._asdict() for k, v in self.scores.items()}, )._asdict()
class BeamSearch(paddle.nn.Layer): class BeamSearch(paddle.nn.Layer):
"""Beam search implementation.""" """Beam search implementation."""
def __init__( def __init__(
self, self,
scorers: Dict[str, ScorerInterface], scorers: Dict[str, ScorerInterface],
weights: Dict[str, float], weights: Dict[str, float],
beam_size: int, beam_size: int,
vocab_size: int, vocab_size: int,
sos: int, sos: int,
eos: int, eos: int,
token_list: List[str] = None, token_list: List[str]=None,
pre_beam_ratio: float = 1.5, pre_beam_ratio: float=1.5,
pre_beam_score_key: str = None, pre_beam_score_key: str=None, ):
):
"""Initialize beam search. """Initialize beam search.
Args: Args:
@ -71,12 +82,12 @@ class BeamSearch(paddle.nn.Layer):
super().__init__() super().__init__()
# set scorers # set scorers
self.weights = weights self.weights = weights
self.scorers = dict() # all = full + partial self.scorers = dict() # all = full + partial
self.full_scorers = dict() # full tokens self.full_scorers = dict() # full tokens
self.part_scorers = dict() # partial tokens self.part_scorers = dict() # partial tokens
# this module dict is required for recursive cast # this module dict is required for recursive cast
# `self.to(device, dtype)` in `recog.py` # `self.to(device, dtype)` in `recog.py`
self.nn_dict = paddle.nn.LayerDict() # nn.Layer self.nn_dict = paddle.nn.LayerDict() # nn.Layer
for k, v in scorers.items(): for k, v in scorers.items():
w = weights.get(k, 0) w = weights.get(k, 0)
if w == 0 or v is None: if w == 0 or v is None:
@ -100,20 +111,16 @@ class BeamSearch(paddle.nn.Layer):
self.pre_beam_size = int(pre_beam_ratio * beam_size) self.pre_beam_size = int(pre_beam_ratio * beam_size)
self.beam_size = beam_size self.beam_size = beam_size
self.n_vocab = vocab_size self.n_vocab = vocab_size
if ( if (pre_beam_score_key is not None and pre_beam_score_key != "full" and
pre_beam_score_key is not None pre_beam_score_key not in self.full_scorers):
and pre_beam_score_key != "full" raise KeyError(
and pre_beam_score_key not in self.full_scorers f"{pre_beam_score_key} is not found in {self.full_scorers}")
):
raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
# selected `key` scorer to do pre beam search # selected `key` scorer to do pre beam search
self.pre_beam_score_key = pre_beam_score_key self.pre_beam_score_key = pre_beam_score_key
# do_pre_beam when need, valid and has part_scorers # do_pre_beam when need, valid and has part_scorers
self.do_pre_beam = ( self.do_pre_beam = (self.pre_beam_score_key is not None and
self.pre_beam_score_key is not None self.pre_beam_size < self.n_vocab and
and self.pre_beam_size < self.n_vocab len(self.part_scorers) > 0)
and len(self.part_scorers) > 0
)
def init_hyp(self, x: paddle.Tensor) -> List[Hypothesis]: def init_hyp(self, x: paddle.Tensor) -> List[Hypothesis]:
"""Get an initial hypothesis data. """Get an initial hypothesis data.
@ -135,12 +142,12 @@ class BeamSearch(paddle.nn.Layer):
yseq=paddle.to_tensor([self.sos], place=x.place), yseq=paddle.to_tensor([self.sos], place=x.place),
score=0.0, score=0.0,
scores=init_scores, scores=init_scores,
states=init_states, states=init_states, )
)
] ]
@staticmethod @staticmethod
def append_token(xs: paddle.Tensor, x: Union[int, paddle.Tensor]) -> paddle.Tensor: def append_token(xs: paddle.Tensor,
x: Union[int, paddle.Tensor]) -> paddle.Tensor:
"""Append new token to prefix tokens. """Append new token to prefix tokens.
Args: Args:
@ -154,9 +161,8 @@ class BeamSearch(paddle.nn.Layer):
x = paddle.to_tensor([x], dtype=xs.dtype) if isinstance(x, int) else x x = paddle.to_tensor([x], dtype=xs.dtype) if isinstance(x, int) else x
return paddle.concat((xs, x)) return paddle.concat((xs, x))
def score_full( def score_full(self, hyp: Hypothesis, x: paddle.Tensor
self, hyp: Hypothesis, x: paddle.Tensor ) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]:
) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.full_scorers`. """Score new hypothesis by `self.full_scorers`.
Args: Args:
@ -178,9 +184,11 @@ class BeamSearch(paddle.nn.Layer):
scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x) scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
return scores, states return scores, states
def score_partial( def score_partial(self,
self, hyp: Hypothesis, ids: paddle.Tensor, x: paddle.Tensor hyp: Hypothesis,
) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: ids: paddle.Tensor,
x: paddle.Tensor
) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.part_scorers`. """Score new hypothesis by `self.part_scorers`.
Args: Args:
@ -201,12 +209,12 @@ class BeamSearch(paddle.nn.Layer):
states = dict() states = dict()
for k, d in self.part_scorers.items(): for k, d in self.part_scorers.items():
# scores[k] shape (len(ids),) # scores[k] shape (len(ids),)
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x) scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k],
x)
return scores, states return scores, states
def beam( def beam(self, weighted_scores: paddle.Tensor,
self, weighted_scores: paddle.Tensor, ids: paddle.Tensor ids: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute topk full token ids and partial token ids. """Compute topk full token ids and partial token ids.
Args: Args:
@ -223,7 +231,8 @@ class BeamSearch(paddle.nn.Layer):
""" """
# no pre beam performed, `ids` equal to `weighted_scores` # no pre beam performed, `ids` equal to `weighted_scores`
if weighted_scores.size(0) == ids.size(0): if weighted_scores.size(0) == ids.size(0):
top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab top_ids = weighted_scores.topk(
self.beam_size)[1] # index in n_vocab
return top_ids, top_ids return top_ids, top_ids
# mask pruned in pre-beam not to select in topk # mask pruned in pre-beam not to select in topk
@ -231,18 +240,18 @@ class BeamSearch(paddle.nn.Layer):
weighted_scores[:] = -float("inf") weighted_scores[:] = -float("inf")
weighted_scores[ids] = tmp weighted_scores[ids] = tmp
# top_ids no equal to local_ids, since ids shape not same # top_ids no equal to local_ids, since ids shape not same
top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab
local_ids = weighted_scores[ids].topk(self.beam_size)[1] # index in len(ids) local_ids = weighted_scores[ids].topk(
self.beam_size)[1] # index in len(ids)
return top_ids, local_ids return top_ids, local_ids
@staticmethod @staticmethod
def merge_scores( def merge_scores(
prev_scores: Dict[str, float], prev_scores: Dict[str, float],
next_full_scores: Dict[str, paddle.Tensor], next_full_scores: Dict[str, paddle.Tensor],
full_idx: int, full_idx: int,
next_part_scores: Dict[str, paddle.Tensor], next_part_scores: Dict[str, paddle.Tensor],
part_idx: int, part_idx: int, ) -> Dict[str, paddle.Tensor]:
) -> Dict[str, paddle.Tensor]:
"""Merge scores for new hypothesis. """Merge scores for new hypothesis.
Args: Args:
@ -288,9 +297,8 @@ class BeamSearch(paddle.nn.Layer):
new_states[k] = d.select_state(part_states[k], part_idx) new_states[k] = d.select_state(part_states[k], part_idx)
return new_states return new_states
def search( def search(self, running_hyps: List[Hypothesis],
self, running_hyps: List[Hypothesis], x: paddle.Tensor x: paddle.Tensor) -> List[Hypothesis]:
) -> List[Hypothesis]:
"""Search new tokens for running hypotheses and encoded speech x. """Search new tokens for running hypotheses and encoded speech x.
Args: Args:
@ -311,11 +319,9 @@ class BeamSearch(paddle.nn.Layer):
weighted_scores += self.weights[k] * scores[k] weighted_scores += self.weights[k] * scores[k]
# partial scoring # partial scoring
if self.do_pre_beam: if self.do_pre_beam:
pre_beam_scores = ( pre_beam_scores = (weighted_scores
weighted_scores if self.pre_beam_score_key == "full" else
if self.pre_beam_score_key == "full" scores[self.pre_beam_score_key])
else scores[self.pre_beam_score_key]
)
part_ids = paddle.topk(pre_beam_scores, self.pre_beam_size)[1] part_ids = paddle.topk(pre_beam_scores, self.pre_beam_size)[1]
part_scores, part_states = self.score_partial(hyp, part_ids, x) part_scores, part_states = self.score_partial(hyp, part_ids, x)
for k in self.part_scorers: for k in self.part_scorers:
@ -331,22 +337,21 @@ class BeamSearch(paddle.nn.Layer):
Hypothesis( Hypothesis(
score=weighted_scores[j], score=weighted_scores[j],
yseq=self.append_token(hyp.yseq, j), yseq=self.append_token(hyp.yseq, j),
scores=self.merge_scores( scores=self.merge_scores(hyp.scores, scores, j,
hyp.scores, scores, j, part_scores, part_j part_scores, part_j),
),
states=self.merge_states(states, part_states, part_j), states=self.merge_states(states, part_states, part_j),
) ))
)
# sort and prune 2 x beam -> beam # sort and prune 2 x beam -> beam
best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ best_hyps = sorted(
: min(len(best_hyps), self.beam_size) best_hyps, key=lambda x: x.score,
] reverse=True)[:min(len(best_hyps), self.beam_size)]
return best_hyps return best_hyps
def forward( def forward(self,
self, x: paddle.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 x: paddle.Tensor,
) -> List[Hypothesis]: maxlenratio: float=0.0,
minlenratio: float=0.0) -> List[Hypothesis]:
"""Perform beam search. """Perform beam search.
Args: Args:
@ -381,9 +386,11 @@ class BeamSearch(paddle.nn.Layer):
logger.debug("position " + str(i)) logger.debug("position " + str(i))
best = self.search(running_hyps, x) best = self.search(running_hyps, x)
# post process of one iteration # post process of one iteration
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) running_hyps = self.post_process(i, maxlen, maxlenratio, best,
ended_hyps)
# end detection # end detection
if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): if maxlenratio == 0.0 and end_detect(
[h.asdict() for h in ended_hyps], i):
logger.info(f"end detected at {i}") logger.info(f"end detected at {i}")
break break
if len(running_hyps) == 0: if len(running_hyps) == 0:
@ -395,15 +402,10 @@ class BeamSearch(paddle.nn.Layer):
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
# check the number of hypotheses reaching to eos # check the number of hypotheses reaching to eos
if len(nbest_hyps) == 0: if len(nbest_hyps) == 0:
logger.warning( logger.warning("there is no N-best results, perform recognition "
"there is no N-best results, perform recognition " "again with smaller minlenratio.")
"again with smaller minlenratio." return ([] if minlenratio < 0.1 else
) self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)))
return (
[]
if minlenratio < 0.1
else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
)
# report the best result # report the best result
best = nbest_hyps[0] best = nbest_hyps[0]
@ -412,7 +414,9 @@ class BeamSearch(paddle.nn.Layer):
f"{float(v):6.2f} * {self.weights[k]:3} = {float(v) * self.weights[k]:6.2f} for {k}" f"{float(v):6.2f} * {self.weights[k]:3} = {float(v) * self.weights[k]:6.2f} for {k}"
) )
logger.info(f"total log probability: {float(best.score):.2f}") logger.info(f"total log probability: {float(best.score):.2f}")
logger.info(f"normalized log probability: {float(best.score) / len(best.yseq):.2f}") logger.info(
f"normalized log probability: {float(best.score) / len(best.yseq):.2f}"
)
logger.info(f"total number of ended hypotheses: {len(nbest_hyps)}") logger.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
if self.token_list is not None: if self.token_list is not None:
# logger.info( # logger.info(
@ -420,21 +424,17 @@ class BeamSearch(paddle.nn.Layer):
# + "".join([self.token_list[x] for x in best.yseq[1:-1]]) # + "".join([self.token_list[x] for x in best.yseq[1:-1]])
# + "\n" # + "\n"
# ) # )
logger.info( logger.info("best hypo: " + "".join(
"best hypo: " [self.token_list[x] for x in best.yseq[1:]]) + "\n")
+ "".join([self.token_list[x] for x in best.yseq[1:]])
+ "\n"
)
return nbest_hyps return nbest_hyps
def post_process( def post_process(
self, self,
i: int, i: int,
maxlen: int, maxlen: int,
maxlenratio: float, maxlenratio: float,
running_hyps: List[Hypothesis], running_hyps: List[Hypothesis],
ended_hyps: List[Hypothesis], ended_hyps: List[Hypothesis], ) -> List[Hypothesis]:
) -> List[Hypothesis]:
"""Perform post-processing of beam search iterations. """Perform post-processing of beam search iterations.
Args: Args:
@ -450,10 +450,8 @@ class BeamSearch(paddle.nn.Layer):
""" """
logger.debug(f"the number of running hypotheses: {len(running_hyps)}") logger.debug(f"the number of running hypotheses: {len(running_hyps)}")
if self.token_list is not None: if self.token_list is not None:
logger.debug( logger.debug("best hypo: " + "".join(
"best hypo: " [self.token_list[x] for x in running_hyps[0].yseq[1:]]))
+ "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
)
# add eos in the final loop to avoid that there are no ended hyps # add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1: if i == maxlen - 1:
logger.info("adding <eos> in the last position in the loop") logger.info("adding <eos> in the last position in the loop")
@ -468,7 +466,8 @@ class BeamSearch(paddle.nn.Layer):
for hyp in running_hyps: for hyp in running_hyps:
if hyp.yseq[-1] == self.eos: if hyp.yseq[-1] == self.eos:
# e.g., Word LM needs to add final <eos> score # e.g., Word LM needs to add final <eos> score
for k, d in chain(self.full_scorers.items(), self.part_scorers.items()): for k, d in chain(self.full_scorers.items(),
self.part_scorers.items()):
s = d.final_score(hyp.states[k]) s = d.final_score(hyp.states[k])
hyp.scores[k] += s hyp.scores[k] += s
hyp = hyp._replace(score=hyp.score + self.weights[k] * s) hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
@ -479,19 +478,18 @@ class BeamSearch(paddle.nn.Layer):
def beam_search( def beam_search(
x: paddle.Tensor, x: paddle.Tensor,
sos: int, sos: int,
eos: int, eos: int,
beam_size: int, beam_size: int,
vocab_size: int, vocab_size: int,
scorers: Dict[str, ScorerInterface], scorers: Dict[str, ScorerInterface],
weights: Dict[str, float], weights: Dict[str, float],
token_list: List[str] = None, token_list: List[str]=None,
maxlenratio: float = 0.0, maxlenratio: float=0.0,
minlenratio: float = 0.0, minlenratio: float=0.0,
pre_beam_ratio: float = 1.5, pre_beam_ratio: float=1.5,
pre_beam_score_key: str = "full", pre_beam_score_key: str="full", ) -> list:
) -> list:
"""Perform beam search with scorers. """Perform beam search with scorers.
Args: Args:
@ -527,6 +525,6 @@ def beam_search(
pre_beam_score_key=pre_beam_score_key, pre_beam_score_key=pre_beam_score_key,
sos=sos, sos=sos,
eos=eos, eos=eos,
token_list=token_list, token_list=token_list, ).forward(
).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
return [h.asdict() for h in ret] return [h.asdict() for h in ret]

@ -1,37 +1,57 @@
"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`.""" # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
import json import json
from pathlib import Path
import jsonlines
import paddle import paddle
import yaml import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
from pathlib import Path
import jsonlines
# from espnet.asr.asr_utils import get_model_conf from .beam_search import BatchBeamSearch
# from espnet.asr.asr_utils import torch_load
# from espnet.asr.pytorch_backend.asr import load_trained_model
# from espnet.nets.lm_interface import dynamic_import_lm
from deepspeech.models.asr_interface import ASRInterface
from .utils import add_results_to_json
# from .batch_beam_search import BatchBeamSearch
from .beam_search import BeamSearch from .beam_search import BeamSearch
from .scorers.scorer_interface import BatchScorerInterface
from .scorers.length_bonus import LengthBonus from .scorers.length_bonus import LengthBonus
from .scorers.scorer_interface import BatchScorerInterface
from .utils import add_results_to_json
from deepspeech.exps import dynamic_import_tester
from deepspeech.io.reader import LoadInputsAndTargets from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
# from espnet.asr.asr_utils import get_model_conf
# from espnet.asr.asr_utils import torch_load
# from espnet.nets.lm_interface import dynamic_import_lm
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
# NOTE: you need this func to generate our sphinx doc
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.utility import print_arguments
model_test_alias = { def load_trained_model(args):
"u2": "deepspeech.exps.u2.model:U2Tester", args.nprocs = args.ngpu
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester", confs = CfgNode()
} confs.set_new_allowed(True)
confs.merge_from_file(args.model_conf)
class_obj = dynamic_import_tester(args.model_name)
exp = class_obj(confs, args)
with exp.eval():
exp.setup()
exp.restore()
char_list = exp.args.char_list
model = exp.model
return model, char_list, exp, confs
def recog_v2(args): def recog_v2(args):
"""Decode with custom models that implements ScorerInterface. """Decode with custom models that implements ScorerInterface.
@ -48,33 +68,17 @@ def recog_v2(args):
raise NotImplementedError("streaming mode is not implemented") raise NotImplementedError("streaming mode is not implemented")
if args.word_rnnlm: if args.word_rnnlm:
raise NotImplementedError("word LM is not implemented") raise NotImplementedError("word LM is not implemented")
args.nprocs = args.ngpu
# set_deterministic(args)
#model, train_args = load_trained_model(args.model)
model_path = Path(args.model)
ckpt_dir = model_path.parent.parent
confs = CfgNode()
confs.set_new_allowed(True)
confs.merge_from_file(args.model_conf)
class_obj = dynamic_import(args.model_name, model_test_alias)
exp = class_obj(confs, args)
with exp.eval():
exp.setup()
exp.restore()
char_list = exp.args.char_list
model = exp.model # set_deterministic(args)
model, char_list, exp, confs = load_trained_model(args)
assert isinstance(model, ASRInterface) assert isinstance(model, ASRInterface)
load_inputs_and_targets = LoadInputsAndTargets( load_inputs_and_targets = LoadInputsAndTargets(
mode="asr", mode="asr",
load_output=False, load_output=False,
sort_in_input_length=False, sort_in_input_length=False,
preprocess_conf=confs.collator.augmentation_config preprocess_conf=confs.collator.augmentation_config
if args.preprocess_conf is None if args.preprocess_conf is None else args.preprocess_conf,
else args.preprocess_conf,
preprocess_args={"train": False}, preprocess_args={"train": False},
) )
@ -100,7 +104,7 @@ def recog_v2(args):
else: else:
ngram = None ngram = None
scorers = model.scorers() scorers = model.scorers() # decoder
scorers["lm"] = lm scorers["lm"] = lm
scorers["ngram"] = ngram scorers["ngram"] = ngram
scorers["length_bonus"] = LengthBonus(len(char_list)) scorers["length_bonus"] = LengthBonus(len(char_list))
@ -125,18 +129,15 @@ def recog_v2(args):
# TODO(karita): make all scorers batchfied # TODO(karita): make all scorers batchfied
if args.batchsize == 1: if args.batchsize == 1:
non_batch = [ non_batch = [
k k for k, v in beam_search.full_scorers.items()
for k, v in beam_search.full_scorers.items()
if not isinstance(v, BatchScorerInterface) if not isinstance(v, BatchScorerInterface)
] ]
if len(non_batch) == 0: if len(non_batch) == 0:
beam_search.__class__ = BatchBeamSearch beam_search.__class__ = BatchBeamSearch
logger.info("BatchBeamSearch implementation is selected.") logger.info("BatchBeamSearch implementation is selected.")
else: else:
logger.warning( logger.warning(f"As non-batch scorers {non_batch} are found, "
f"As non-batch scorers {non_batch} are found, " f"fall back to non-batch implementation.")
f"fall back to non-batch implementation."
)
if args.ngpu > 1: if args.ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported") raise NotImplementedError("only single GPU decoding is supported")
@ -157,7 +158,7 @@ def recog_v2(args):
with jsonlines.open(args.recog_json, "r") as reader: with jsonlines.open(args.recog_json, "r") as reader:
for item in reader: for item in reader:
js.append(item) js.append(item)
# josnlines to dict, key by 'utt' # jsonlines to dict, key by 'utt', value by jsonline
js = {item['utt']: item for item in js} js = {item['utt']: item for item in js}
new_js = {} new_js = {}
@ -169,25 +170,26 @@ def recog_v2(args):
feat = load_inputs_and_targets(batch)[0][0] feat = load_inputs_and_targets(batch)[0][0]
logger.info(f'feat: {feat.shape}') logger.info(f'feat: {feat.shape}')
enc = model.encode(paddle.to_tensor(feat).to(dtype)) enc = model.encode(paddle.to_tensor(feat).to(dtype))
logger.info(f'eouts: {enc.shape}') logger.info(f'eout: {enc.shape}')
nbest_hyps = beam_search( nbest_hyps = beam_search(x=enc,
x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio maxlenratio=args.maxlenratio,
) minlenratio=args.minlenratio)
nbest_hyps = [ nbest_hyps = [
h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)] h.asdict()
for h in nbest_hyps[:min(len(nbest_hyps), args.nbest)]
] ]
new_js[name] = add_results_to_json( new_js[name] = add_results_to_json(js[name], nbest_hyps,
js[name], nbest_hyps, char_list char_list)
)
item = new_js[name]['output'][0] # 1-best item = new_js[name]['output'][0] # 1-best
utt = name
ref = item['text'] ref = item['text']
rec_text = item['rec_text'].replace('', ' ').replace('<eos>', '').strip() rec_text = item['rec_text'].replace('',
rec_tokenid = map(int, item['rec_tokenid'].split()) ' ').replace('<eos>',
'').strip()
rec_tokenid = list(map(int, item['rec_tokenid'].split()))
f.write({ f.write({
"utt": utt, "utt": name,
"refs": [ref], "refs": [ref],
"hyps": [rec_text], "hyps": [rec_text],
"hyps_tokenid": [rec_tokenid], "hyps_tokenid": [rec_tokenid],
}) })

@ -0,0 +1,376 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end speech recognition model decoding script."""
import logging
import os
import random
import sys
from distutils.util import strtobool
import configargparse
import numpy as np
from .recog import recog_v2
def get_parser():
"""Get default arguments."""
parser = configargparse.ArgumentParser(
description="Transcribe text from speech using "
"a speech recognition model on one CPU or GPU",
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter, )
parser.add(
'--model-name',
type=str,
default='u2_kaldi',
help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
# general configuration
parser.add("--config", is_config_file=True, help="Config file path")
parser.add(
"--config2",
is_config_file=True,
help="Second config file path that overwrites the settings in `--config`",
)
parser.add(
"--config3",
is_config_file=True,
help="Third config file path that overwrites the settings "
"in `--config` and `--config2`", )
parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
parser.add_argument(
"--dtype",
choices=("float16", "float32", "float64"),
default="float32",
help="Float precision (only available in --api v2)", )
parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
parser.add_argument("--seed", type=int, default=1, help="Random seed")
parser.add_argument(
"--verbose", "-V", type=int, default=2, help="Verbose option")
parser.add_argument(
"--batchsize",
type=int,
default=1,
help="Batch size for beam search (0: means no batch processing)", )
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing", )
parser.add_argument(
"--api",
default="v2",
choices=["v2"],
help="Beam search APIs "
"v2: Experimental API. It supports any models that implements ScorerInterface.",
)
# task related
parser.add_argument(
"--recog-json", type=str, help="Filename of recognition data (json)")
parser.add_argument(
"--result-label",
type=str,
required=True,
help="Filename of result label data (json)", )
# model (parameter) related
parser.add_argument(
"--model",
type=str,
required=True,
help="Model file parameters to read")
parser.add_argument(
"--model-conf", type=str, default=None, help="Model config file")
parser.add_argument(
"--num-spkrs",
type=int,
default=1,
choices=[1, 2],
help="Number of speakers in the speech", )
parser.add_argument(
"--num-encs",
default=1,
type=int,
help="Number of encoders in the model.")
# search related
parser.add_argument(
"--nbest", type=int, default=1, help="Output N-best hypotheses")
parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
parser.add_argument(
"--penalty", type=float, default=0.0, help="Incertion penalty")
parser.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="""Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses a end-detect function
to automatically find maximum hypothesis lengths.
If maxlenratio<0.0, its absolute value is interpreted
as a constant max output length""", )
parser.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length", )
parser.add_argument(
"--ctc-weight",
type=float,
default=0.0,
help="CTC weight in joint decoding")
parser.add_argument(
"--weights-ctc-dec",
type=float,
action="append",
help="ctc weight assigned to each encoder during decoding."
"[in multi-encoder mode only]", )
parser.add_argument(
"--ctc-window-margin",
type=int,
default=0,
help="""Use CTC window with margin parameter to accelerate
CTC/attention decoding especially on GPU. Smaller magin
makes decoding faster, but may increase search errors.
If margin=0 (default), this function is disabled""", )
# transducer related
parser.add_argument(
"--search-type",
type=str,
default="default",
choices=["default", "nsc", "tsd", "alsd", "maes"],
help="""Type of beam search implementation to use during inference.
Can be either: default beam search ("default"),
N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"),
Alignment-Length Synchronous Decoding ("alsd") or
modified Adaptive Expansion Search ("maes").""", )
parser.add_argument(
"--nstep",
type=int,
default=1,
help="""Number of expansion steps allowed in NSC beam search or mAES
(nstep > 0 for NSC and nstep > 1 for mAES).""", )
parser.add_argument(
"--prefix-alpha",
type=int,
default=2,
help="Length prefix difference allowed in NSC beam search or mAES.", )
parser.add_argument(
"--max-sym-exp",
type=int,
default=2,
help="Number of symbol expansions allowed in TSD.", )
parser.add_argument(
"--u-max",
type=int,
default=400,
help="Length prefix difference allowed in ALSD.", )
parser.add_argument(
"--expansion-gamma",
type=float,
default=2.3,
help="Allowed logp difference for prune-by-value method in mAES.", )
parser.add_argument(
"--expansion-beta",
type=int,
default=2,
help="""Number of additional candidates for expanded hypotheses
selection in mAES.""", )
parser.add_argument(
"--score-norm",
type=strtobool,
nargs="?",
default=True,
help="Normalize final hypotheses' score by length", )
parser.add_argument(
"--softmax-temperature",
type=float,
default=1.0,
help="Penalization term for softmax function.", )
# rnnlm related
parser.add_argument(
"--rnnlm", type=str, default=None, help="RNNLM model file to read")
parser.add_argument(
"--rnnlm-conf",
type=str,
default=None,
help="RNNLM model config file to read")
parser.add_argument(
"--word-rnnlm",
type=str,
default=None,
help="Word RNNLM model file to read")
parser.add_argument(
"--word-rnnlm-conf",
type=str,
default=None,
help="Word RNNLM model config file to read", )
parser.add_argument(
"--word-dict", type=str, default=None, help="Word list to read")
parser.add_argument(
"--lm-weight", type=float, default=0.1, help="RNNLM weight")
# ngram related
parser.add_argument(
"--ngram-model",
type=str,
default=None,
help="ngram model file to read")
parser.add_argument(
"--ngram-weight", type=float, default=0.1, help="ngram weight")
parser.add_argument(
"--ngram-scorer",
type=str,
default="part",
choices=("full", "part"),
help="""if the ngram is set as a part scorer, similar with CTC scorer,
ngram scorer only scores topK hypethesis.
if the ngram is set as full scorer, ngram scorer scores all hypthesis
the decoding speed of part scorer is musch faster than full one""",
)
# streaming related
parser.add_argument(
"--streaming-mode",
type=str,
default=None,
choices=["window", "segment"],
help="""Use streaming recognizer for inference.
`--batchsize` must be set to 0 to enable this mode""", )
parser.add_argument(
"--streaming-window", type=int, default=10, help="Window size")
parser.add_argument(
"--streaming-min-blank-dur",
type=int,
default=10,
help="Minimum blank duration threshold", )
parser.add_argument(
"--streaming-onset-margin", type=int, default=1, help="Onset margin")
parser.add_argument(
"--streaming-offset-margin", type=int, default=1, help="Offset margin")
# non-autoregressive related
# Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
parser.add_argument(
"--maskctc-n-iterations",
type=int,
default=10,
help="Number of decoding iterations."
"For Mask CTC, set 0 to predict 1 mask/iter.", )
parser.add_argument(
"--maskctc-probability-threshold",
type=float,
default=0.999,
help="Threshold probability for CTC output", )
# quantize model related
parser.add_argument(
"--quantize-config",
nargs="*",
help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]",
)
parser.add_argument(
"--quantize-dtype",
type=str,
default="qint8",
help="Dtype dynamic quantize")
parser.add_argument(
"--quantize-asr-model",
type=bool,
default=False,
help="Quantize asr model", )
parser.add_argument(
"--quantize-lm-model",
type=bool,
default=False,
help="Quantize lm model", )
return parser
def main(args):
"""Run the main decoding function."""
parser = get_parser()
parser.add_argument(
"--output", metavar="CKPT_DIR", help="path to save checkpoint.")
parser.add_argument(
"--checkpoint_path", type=str, help="path to load checkpoint")
parser.add_argument("--dict-path", type=str, help="path to load checkpoint")
args = parser.parse_args(args)
if args.ngpu == 0 and args.dtype == "float16":
raise ValueError(
f"--dtype {args.dtype} does not support the CPU backend.")
# logging info
if args.verbose == 1:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
elif args.verbose == 2:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.warning("Skip DEBUG/INFO messages")
logging.info(args)
# check CUDA_VISIBLE_DEVICES
if args.ngpu > 0:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is None:
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
elif args.ngpu != len(cvd.split(",")):
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
sys.exit(1)
# TODO(mn5k): support of multiple GPUs
if args.ngpu > 1:
logging.error("The program only supports ngpu=1.")
sys.exit(1)
# display PYTHONPATH
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
# seed setting
random.seed(args.seed)
np.random.seed(args.seed)
logging.info("set random seed = %d" % args.seed)
# validate rnn options
if args.rnnlm is not None and args.word_rnnlm is not None:
logging.error(
"It seems that both --rnnlm and --word-rnnlm are specified. "
"Please use either option.")
sys.exit(1)
# recog
if args.num_spkrs == 1:
if args.num_encs == 1:
# Experimental API that supports custom LMs
if args.api == "v2":
from deepspeech.decoders.recog import recog_v2
recog_v2(args)
else:
raise ValueError("Only support --api v2")
else:
if args.api == "v2":
raise NotImplementedError(
f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
)
elif args.num_spkrs == 2:
raise ValueError("asr_mix not supported.")
if __name__ == "__main__":
main(sys.argv[1:])

@ -85,8 +85,9 @@ class NgramFullScorer(Ngrambase, BatchScorerInterface):
and next state list for ys. and next state list for ys.
""" """
return self.score_partial_( return self.score_partial_(y,
y, paddle.to_tensor(range(self.charlen)), state, x) paddle.to_tensor(range(self.charlen)), state,
x)
class NgramPartScorer(Ngrambase, PartialScorerInterface): class NgramPartScorer(Ngrambase, PartialScorerInterface):

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -98,7 +98,8 @@ def add_results_to_json(js, nbest_hyps, char_list):
for n, hyp in enumerate(nbest_hyps, 1): for n, hyp in enumerate(nbest_hyps, 1):
# parse hypothesis # parse hypothesis
rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list) rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp,
char_list)
# copy ground-truth # copy ground-truth
if len(js["output"]) > 0: if len(js["output"]) > 0:
@ -125,4 +126,4 @@ def add_results_to_json(js, nbest_hyps, char_list):
logger.info("groundtruth: %s" % out_dic["text"]) logger.info("groundtruth: %s" % out_dic["text"])
logger.info("prediction : %s" % out_dic["rec_text"]) logger.info("prediction : %s" % out_dic["rec_text"])
return new_js return new_js

@ -11,3 +11,52 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from deepspeech.training.trainer import Trainer
from deepspeech.utils.dynamic_import import dynamic_import
model_trainer_alias = {
"ds2": "deepspeech.exp.deepspeech2.model:DeepSpeech2Trainer",
"u2": "deepspeech.exps.u2.model:U2Trainer",
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer",
"u2_st": "deepspeech.exps.u2_st.model:U2STTrainer",
}
def dynamic_import_trainer(module):
"""Import Trainer dynamically.
Args:
module (str): trainer name. e.g., ds2, u2, u2_kaldi
Returns:
type: Trainer class
"""
model_class = dynamic_import(module, model_trainer_alias)
assert issubclass(model_class,
Trainer), f"{module} does not implement Trainer"
return model_class
model_tester_alias = {
"ds2": "deepspeech.exp.deepspeech2.model:DeepSpeech2Tester",
"u2": "deepspeech.exps.u2.model:U2Tester",
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester",
"u2_st": "deepspeech.exps.u2_st.model:U2STTester",
}
def dynamic_import_tester(module):
"""Import Tester dynamically.
Args:
module (str): tester name. e.g., ds2, u2, u2_kaldi
Returns:
type: Tester class
"""
model_class = dynamic_import(module, model_tester_alias)
assert issubclass(model_class,
Trainer), f"{module} does not implement Tester"
return model_class

@ -1,379 +1,19 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
"""End-to-end speech recognition model decoding script.""" #
# Licensed under the Apache License, Version 2.0 (the "License");
import configargparse # you may not use this file except in compliance with the License.
import logging # You may obtain a copy of the License at
import os #
import random # http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys import sys
import numpy as np from deepspeech.decoders.recog_bin import main
from distutils.util import strtobool
from deepspeech.training.cli import default_argument_parser
# NOTE: you need this func to generate our sphinx doc
def get_parser():
"""Get default arguments."""
parser = configargparse.ArgumentParser(
description="Transcribe text from speech using "
"a speech recognition model on one CPU or GPU",
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
)
parser.add(
'--model-name',
type=str,
default='u2_kaldi',
help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
# general configuration
parser.add("--config", is_config_file=True, help="Config file path")
parser.add(
"--config2",
is_config_file=True,
help="Second config file path that overwrites the settings in `--config`",
)
parser.add(
"--config3",
is_config_file=True,
help="Third config file path that overwrites the settings "
"in `--config` and `--config2`",
)
parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
parser.add_argument(
"--dtype",
choices=("float16", "float32", "float64"),
default="float32",
help="Float precision (only available in --api v2)",
)
parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
parser.add_argument("--seed", type=int, default=1, help="Random seed")
parser.add_argument("--verbose", "-V", type=int, default=2, help="Verbose option")
parser.add_argument(
"--batchsize",
type=int,
default=1,
help="Batch size for beam search (0: means no batch processing)",
)
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing",
)
parser.add_argument(
"--api",
default="v2",
choices=["v2"],
help="Beam search APIs "
"v2: Experimental API. It supports any models that implements ScorerInterface.",
)
# task related
parser.add_argument(
"--recog-json", type=str, help="Filename of recognition data (json)"
)
parser.add_argument(
"--result-label",
type=str,
required=True,
help="Filename of result label data (json)",
)
# model (parameter) related
parser.add_argument(
"--model", type=str, required=True, help="Model file parameters to read"
)
parser.add_argument(
"--model-conf", type=str, default=None, help="Model config file"
)
parser.add_argument(
"--num-spkrs",
type=int,
default=1,
choices=[1, 2],
help="Number of speakers in the speech",
)
parser.add_argument(
"--num-encs", default=1, type=int, help="Number of encoders in the model."
)
# search related
parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty")
parser.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="""Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses a end-detect function
to automatically find maximum hypothesis lengths.
If maxlenratio<0.0, its absolute value is interpreted
as a constant max output length""",
)
parser.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
parser.add_argument(
"--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding"
)
parser.add_argument(
"--weights-ctc-dec",
type=float,
action="append",
help="ctc weight assigned to each encoder during decoding."
"[in multi-encoder mode only]",
)
parser.add_argument(
"--ctc-window-margin",
type=int,
default=0,
help="""Use CTC window with margin parameter to accelerate
CTC/attention decoding especially on GPU. Smaller magin
makes decoding faster, but may increase search errors.
If margin=0 (default), this function is disabled""",
)
# transducer related
parser.add_argument(
"--search-type",
type=str,
default="default",
choices=["default", "nsc", "tsd", "alsd", "maes"],
help="""Type of beam search implementation to use during inference.
Can be either: default beam search ("default"),
N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"),
Alignment-Length Synchronous Decoding ("alsd") or
modified Adaptive Expansion Search ("maes").""",
)
parser.add_argument(
"--nstep",
type=int,
default=1,
help="""Number of expansion steps allowed in NSC beam search or mAES
(nstep > 0 for NSC and nstep > 1 for mAES).""",
)
parser.add_argument(
"--prefix-alpha",
type=int,
default=2,
help="Length prefix difference allowed in NSC beam search or mAES.",
)
parser.add_argument(
"--max-sym-exp",
type=int,
default=2,
help="Number of symbol expansions allowed in TSD.",
)
parser.add_argument(
"--u-max",
type=int,
default=400,
help="Length prefix difference allowed in ALSD.",
)
parser.add_argument(
"--expansion-gamma",
type=float,
default=2.3,
help="Allowed logp difference for prune-by-value method in mAES.",
)
parser.add_argument(
"--expansion-beta",
type=int,
default=2,
help="""Number of additional candidates for expanded hypotheses
selection in mAES.""",
)
parser.add_argument(
"--score-norm",
type=strtobool,
nargs="?",
default=True,
help="Normalize final hypotheses' score by length",
)
parser.add_argument(
"--softmax-temperature",
type=float,
default=1.0,
help="Penalization term for softmax function.",
)
# rnnlm related
parser.add_argument(
"--rnnlm", type=str, default=None, help="RNNLM model file to read"
)
parser.add_argument(
"--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
)
parser.add_argument(
"--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read"
)
parser.add_argument(
"--word-rnnlm-conf",
type=str,
default=None,
help="Word RNNLM model config file to read",
)
parser.add_argument("--word-dict", type=str, default=None, help="Word list to read")
parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight")
# ngram related
parser.add_argument(
"--ngram-model", type=str, default=None, help="ngram model file to read"
)
parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight")
parser.add_argument(
"--ngram-scorer",
type=str,
default="part",
choices=("full", "part"),
help="""if the ngram is set as a part scorer, similar with CTC scorer,
ngram scorer only scores topK hypethesis.
if the ngram is set as full scorer, ngram scorer scores all hypthesis
the decoding speed of part scorer is musch faster than full one""",
)
# streaming related
parser.add_argument(
"--streaming-mode",
type=str,
default=None,
choices=["window", "segment"],
help="""Use streaming recognizer for inference.
`--batchsize` must be set to 0 to enable this mode""",
)
parser.add_argument("--streaming-window", type=int, default=10, help="Window size")
parser.add_argument(
"--streaming-min-blank-dur",
type=int,
default=10,
help="Minimum blank duration threshold",
)
parser.add_argument(
"--streaming-onset-margin", type=int, default=1, help="Onset margin"
)
parser.add_argument(
"--streaming-offset-margin", type=int, default=1, help="Offset margin"
)
# non-autoregressive related
# Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
parser.add_argument(
"--maskctc-n-iterations",
type=int,
default=10,
help="Number of decoding iterations."
"For Mask CTC, set 0 to predict 1 mask/iter.",
)
parser.add_argument(
"--maskctc-probability-threshold",
type=float,
default=0.999,
help="Threshold probability for CTC output",
)
# quantize model related
parser.add_argument(
"--quantize-config",
nargs="*",
help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]",
)
parser.add_argument(
"--quantize-dtype", type=str, default="qint8", help="Dtype dynamic quantize"
)
parser.add_argument(
"--quantize-asr-model",
type=bool,
default=False,
help="Quantize asr model",
)
parser.add_argument(
"--quantize-lm-model",
type=bool,
default=False,
help="Quantize lm model",
)
return parser
def main(args):
"""Run the main decoding function."""
parser = get_parser()
parser.add_argument(
"--output", metavar="CKPT_DIR", help="path to save checkpoint.")
parser.add_argument(
"--checkpoint_path", type=str, help="path to load checkpoint")
parser.add_argument(
"--dict-path", type=str, help="path to load checkpoint")
# parser = default_argument_parser(parser)
args = parser.parse_args(args)
if args.ngpu == 0 and args.dtype == "float16":
raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.")
# logging info
if args.verbose == 1:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
elif args.verbose == 2:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.warning("Skip DEBUG/INFO messages")
logging.info(args)
# check CUDA_VISIBLE_DEVICES
if args.ngpu > 0:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is None:
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
elif args.ngpu != len(cvd.split(",")):
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
sys.exit(1)
# TODO(mn5k): support of multiple GPUs
if args.ngpu > 1:
logging.error("The program only supports ngpu=1.")
sys.exit(1)
# display PYTHONPATH
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
# seed setting
random.seed(args.seed)
np.random.seed(args.seed)
logging.info("set random seed = %d" % args.seed)
# validate rnn options
if args.rnnlm is not None and args.word_rnnlm is not None:
logging.error(
"It seems that both --rnnlm and --word-rnnlm are specified. "
"Please use either option."
)
sys.exit(1)
# recog
if args.num_spkrs == 1:
if args.num_encs == 1:
# Experimental API that supports custom LMs
if args.api == "v2":
from deepspeech.decoders.recog import recog_v2
recog_v2(args)
else:
raise ValueError("Only support --api v2")
else:
if args.api == "v2":
raise NotImplementedError(
f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
)
elif args.num_spkrs == 2:
raise ValueError("asr_mix not supported.")
if __name__ == "__main__": if __name__ == "__main__":
main(sys.argv[1:]) main(sys.argv[1:])

@ -434,8 +434,9 @@ class U2Tester(U2Trainer):
simulate_streaming=cfg.simulate_streaming) simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time decode_time = time.time() - start_time
for i, (utt, target, result, rec_tids) in enumerate(zip( for i, (utt, target, result, rec_tids) in enumerate(
utts, target_transcripts, result_transcripts, result_tokenids)): zip(utts, target_transcripts, result_transcripts,
result_tokenids)):
errors, len_ref = errors_func(target, result) errors, len_ref = errors_func(target, result)
errors_sum += errors errors_sum += errors
len_refs += len_ref len_refs += len_ref

@ -140,7 +140,7 @@ class TextFeaturizer():
Returns: Returns:
str: text string. str: text string.
""" """
tokens = [t.replace(SPACE, " ") for t in tokens ] tokens = [t.replace(SPACE, " ") for t in tokens]
return "".join(tokens) return "".join(tokens)
def word_tokenize(self, text): def word_tokenize(self, text):

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ASR Interface module.""" """ASR Interface module."""
import argparse import argparse
@ -72,7 +85,8 @@ class ASRInterface:
:return: attention weights (B, Lmax, Tmax) :return: attention weights (B, Lmax, Tmax)
:rtype: float ndarray :rtype: float ndarray
""" """
raise NotImplementedError("calculate_all_attentions method is not implemented") raise NotImplementedError(
"calculate_all_attentions method is not implemented")
def calculate_all_ctc_probs(self, xs, ilens, ys): def calculate_all_ctc_probs(self, xs, ilens, ys):
"""Calculate CTC probability. """Calculate CTC probability.
@ -83,7 +97,8 @@ class ASRInterface:
:return: CTC probabilities (B, Tmax, vocab) :return: CTC probabilities (B, Tmax, vocab)
:rtype: float ndarray :rtype: float ndarray
""" """
raise NotImplementedError("calculate_all_ctc_probs method is not implemented") raise NotImplementedError(
"calculate_all_ctc_probs method is not implemented")
@property @property
def attention_plot_class(self): def attention_plot_class(self):
@ -102,8 +117,7 @@ class ASRInterface:
def get_total_subsampling_factor(self): def get_total_subsampling_factor(self):
"""Get total subsampling factor.""" """Get total subsampling factor."""
raise NotImplementedError( raise NotImplementedError(
"get_total_subsampling_factor method is not implemented" "get_total_subsampling_factor method is not implemented")
)
def encode(self, feat): def encode(self, feat):
"""Encode feature in `beam_search` (optional). """Encode feature in `beam_search` (optional).
@ -126,23 +140,22 @@ class ASRInterface:
predefined_asr = { predefined_asr = {
"transformer": "deepspeech.models.u2:E2E", "transformer": "deepspeech.models.u2:U2Model",
"conformer": "deepspeech.models.u2:E2E", "conformer": "deepspeech.models.u2:U2Model",
} }
def dynamic_import_asr(module, name):
def dynamic_import_asr(module):
"""Import ASR models dynamically. """Import ASR models dynamically.
Args: Args:
module (str): module_name:class_name or alias in `predefined_asr` module (str): asr name. e.g., transformer, conformer
name (str): asr name. e.g., transformer, conformer
Returns: Returns:
type: ASR class type: ASR class
""" """
model_class = dynamic_import(module, predefined_asr.get(name, "")) model_class = dynamic_import(module, predefined_asr)
assert issubclass( assert issubclass(model_class,
model_class, ASRInterface ASRInterface), f"{module} does not implement ASRInterface"
), f"{module} does not implement ASRInterface"
return model_class return model_class

@ -28,8 +28,10 @@ from paddle import jit
from paddle import nn from paddle import nn
from yacs.config import CfgNode from yacs.config import CfgNode
from deepspeech.decoders.scorers.ctc import CTCPrefixScorer
from deepspeech.frontend.utility import IGNORE_ID from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.frontend.utility import load_cmvn from deepspeech.frontend.utility import load_cmvn
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.modules.cmvn import GlobalCMVN from deepspeech.modules.cmvn import GlobalCMVN
from deepspeech.modules.ctc import CTCDecoder from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder from deepspeech.modules.decoder import TransformerDecoder
@ -49,8 +51,6 @@ from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add from deepspeech.utils.utility import log_add
from deepspeech.utils.utility import UpdateConfig from deepspeech.utils.utility import UpdateConfig
from deepspeech.models.asr_interface import ASRInterface
from deepspeech.decoders.scorers.ctc import CTCPrefixScorer
__all__ = ["U2Model", "U2InferModel"] __all__ = ["U2Model", "U2InferModel"]
@ -816,10 +816,10 @@ class U2BaseModel(ASRInterface, nn.Layer):
class U2DecodeModel(U2BaseModel): class U2DecodeModel(U2BaseModel):
def scorers(self): def scorers(self):
"""Scorers.""" """Scorers."""
return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) return dict(
decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
def encode(self, x): def encode(self, x):
"""Encode acoustic features. """Encode acoustic features.

@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Decoder definition.""" """Decoder definition."""
from typing import Any
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
from typing import Any
import paddle import paddle
from paddle import nn from paddle import nn
from typeguard import check_argument_types from typeguard import check_argument_types
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.decoder_layer import DecoderLayer from deepspeech.modules.decoder_layer import DecoderLayer
from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.mask import make_non_pad_mask from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.mask import make_xs_mask from deepspeech.modules.mask import make_xs_mask
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -191,8 +191,8 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
ys: (ylen,) ys: (ylen,)
x: (xlen, n_feat) x: (xlen, n_feat)
""" """
ys_mask = subsequent_mask(len(ys)).unsqueeze(0) # (B,L,L) ys_mask = subsequent_mask(len(ys)).unsqueeze(0) # (B,L,L)
x_mask = make_xs_mask(x.unsqueeze(0)).unsqueeze(1) # (B,1,T) x_mask = make_xs_mask(x.unsqueeze(0)).unsqueeze(1) # (B,1,T)
if self.selfattention_layer_type != "selfattn": if self.selfattention_layer_type != "selfattn":
# TODO(karita): implement cache # TODO(karita): implement cache
logging.warning( logging.warning(
@ -200,16 +200,14 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
) )
state = None state = None
logp, state = self.forward_one_step( logp, state = self.forward_one_step(
x.unsqueeze(0), x_mask, x.unsqueeze(0), x_mask, ys.unsqueeze(0), ys_mask, cache=state)
ys.unsqueeze(0), ys_mask,
cache=state
)
return logp.squeeze(0), state return logp.squeeze(0), state
# batch beam search API (see BatchScorerInterface) # batch beam search API (see BatchScorerInterface)
def batch_score( def batch_score(self,
self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor ys: paddle.Tensor,
) -> Tuple[paddle.Tensor, List[Any]]: states: List[Any],
xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
"""Score new token batch (required). """Score new token batch (required).
Args: Args:
@ -237,10 +235,12 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
] ]
# batch decoding # batch decoding
ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0) # (B,L,L) ys_mask = subsequent_mask(ys.size(-1)).unsqueeze(0) # (B,L,L)
xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T) xs_mask = make_xs_mask(xs).unsqueeze(1) # (B,1,T)
logp, states = self.forward_one_step(xs, xs_mask, ys, ys_mask, cache=batch_state) logp, states = self.forward_one_step(
xs, xs_mask, ys, ys_mask, cache=batch_state)
# transpose state of [layer, batch] into [batch, layer] # transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] state_list = [[states[i][b] for i in range(n_layers)]
for b in range(n_batch)]
return logp, state_list return logp, state_list

@ -24,7 +24,7 @@ __all__ = [
] ]
def make_xs_mask(xs:paddle.Tensor, pad_value=0.0) -> paddle.Tensor: def make_xs_mask(xs: paddle.Tensor, pad_value=0.0) -> paddle.Tensor:
"""Maks mask tensor containing indices of non-padded part. """Maks mask tensor containing indices of non-padded part.
Args: Args:
xs (paddle.Tensor): (B, T, D), zeros for pad. xs (paddle.Tensor): (B, T, D), zeros for pad.

@ -64,7 +64,7 @@ def default_argument_parser(parser=None):
""" """
if parser is None: if parser is None:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.register('action', 'extend', ExtendAction) parser.register('action', 'extend', ExtendAction)
parser.add_argument( parser.add_argument(
'--conf', type=open, action=LoadFromFile, help="config file.") '--conf', type=open, action=LoadFromFile, help="config file.")

@ -126,7 +126,8 @@ class Trainer():
logger.info(f"Set seed {args.seed}") logger.info(f"Set seed {args.seed}")
# profiler and benchmark options # profiler and benchmark options
if hasattr(self.args, "benchmark_batch_size") and self.args.benchmark_batch_size: if hasattr(self.args,
"benchmark_batch_size") and self.args.benchmark_batch_size:
with UpdateConfig(self.config): with UpdateConfig(self.config):
self.config.collator.batch_size = self.args.benchmark_batch_size self.config.collator.batch_size = self.args.benchmark_batch_size
self.config.training.log_interval = 1 self.config.training.log_interval = 1
@ -335,8 +336,7 @@ class Trainer():
""" """
assert self.args.checkpoint_path assert self.args.checkpoint_path
infos = self.checkpoint.load_latest_parameters( infos = self.checkpoint.load_latest_parameters(
self.model, self.model, checkpoint_path=self.args.checkpoint_path)
checkpoint_path=self.args.checkpoint_path)
return infos return infos
def run_test(self): def run_test(self):

@ -1,8 +1,8 @@
# ASR # ASR
* s0 is for deepspeech2 offline * s0 is for deepspeech2 offline
* s1 is for transformer/conformer/U2 * s1 is for transformer/conformer/U2
* s2 is for transformer/conformer/U2 w/ kaldi feat, need install Kaldi * s2 is for transformer/conformer/U2 w/ kaldi feat, need install Kaldi
## Data ## Data
| Data Subset | Duration in Seconds | | Data Subset | Duration in Seconds |

@ -1,14 +1,14 @@
# LibriSpeech # LibriSpeech
| Model | Params | Config | Augmentation| Loss | | Model | Params | Config | Augmentation| Loss |
| --- | --- | --- | --- | | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | 6.3197922706604 | | transformer | 32.52 M | conf/transformer.yaml | spec_aug | 6.3197922706604 |
| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err | | Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 | | test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 |
| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 | | test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 | | test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 | | test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 | | test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |

@ -1,6 +1,6 @@
batchsize: 0 batchsize: 0
beam-size: 60 beam-size: 60
ctc-weight: 0.4 ctc-weight: 0.0
lm-weight: 0.0 lm-weight: 0.0
maxlenratio: 0.0 maxlenratio: 0.0
minlenratio: 0.0 minlenratio: 0.0

@ -5,11 +5,14 @@ set -e
expdir=exp expdir=exp
datadir=data datadir=data
nj=32 nj=32
tag=
# decode config
decode_config=conf/decode/decode.yaml decode_config=conf/decode/decode.yaml
# lm params
lang_model=rnnlm.model.best lang_model=rnnlm.model.best
lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/ lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/
lmtag='nolm' lmtag='nolm'
recog_set="test-clean test-other dev-clean dev-other" recog_set="test-clean test-other dev-clean dev-other"
@ -21,18 +24,21 @@ bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}" bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpemodel=${bpeprefix}.model bpemodel=${bpeprefix}.model
if [ $# != 3 ];then # bin params
echo "usage: ${0} config_path dict_path ckpt_path_prefix" config_path=conf/transformer.yaml
exit -1 dict=data/bpe_unigram_5000_units.txt
ckpt_prefix=
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
if [ -z ${ckpt_prefix} ]; then
echo "usage: $0 --ckpt_prefix ckpt_prefix"
exit 1
fi fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
config_path=$1
dict=$2
ckpt_prefix=$3
ckpt_dir=$(dirname `dirname ${ckpt_prefix}`) ckpt_dir=$(dirname `dirname ${ckpt_prefix}`)
echo "ckpt dir: ${ckpt_dir}" echo "ckpt dir: ${ckpt_dir}"
@ -61,7 +67,7 @@ for dmethd in join_ctc; do
for rtask in ${recog_set}; do for rtask in ${recog_set}; do
( (
echo "${rtask} dataset" echo "${rtask} dataset"
decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag} decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag}_${tag}
feat_recog_dir=${datadir} feat_recog_dir=${datadir}
mkdir -p ${decode_dir} mkdir -p ${decode_dir}
mkdir -p ${feat_recog_dir} mkdir -p ${feat_recog_dir}

@ -17,19 +17,20 @@ bpemode=unigram
bpeprefix="data/bpe_${bpemode}_${nbpe}" bpeprefix="data/bpe_${bpemode}_${nbpe}"
bpemodel=${bpeprefix}.model bpemodel=${bpeprefix}.model
if [ $# != 3 ];then config_path=conf/transformer.yaml
echo "usage: ${0} config_path dict_path ckpt_path_prefix" dict=data/bpe_unigram_5000_units.txt
exit -1 ckpt_prefix=
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
if [ -z ${ckpt_prefix} ]; then
echo "usage: $0 --ckpt_prefix ckpt_prefix"
exit 1
fi fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
config_path=$1
dict=$2
ckpt_prefix=$3
ckpt_dir=$(dirname `dirname ${ckpt_prefix}`) ckpt_dir=$(dirname `dirname ${ckpt_prefix}`)
echo "ckpt dir: ${ckpt_dir}" echo "ckpt dir: ${ckpt_dir}"

@ -1,43 +1,43 @@
ConfigArgParse
coverage coverage
editdistance editdistance
g2p_en
g2pM
gpustat gpustat
h5py
inflect
jieba
jsonlines jsonlines
kaldiio kaldiio
librosa
llvmlite
loguru loguru
matplotlib
nltk
numba
numpy==1.20.0
pandas
phkit
Pillow Pillow
praatio~=4.1
pre-commit pre-commit
pybind11 pybind11
pypinyin
pyworld
resampy==0.2.2 resampy==0.2.2
sacrebleu sacrebleu
scipy==1.2.1 scipy==1.2.1
sentencepiece sentencepiece
snakeviz snakeviz
soundfile~=0.10
sox sox
tensorboardX tensorboardX
textgrid textgrid
timer
tqdm tqdm
typeguard typeguard
visualdl==2.2.0
yacs
numpy==1.20.0
numba
nltk
inflect
librosa
unidecode unidecode
llvmlite visualdl==2.2.0
matplotlib
pandas
soundfile~=0.10
g2p_en
pypinyin
webrtcvad webrtcvad
g2pM yacs
praatio~=4.1
h5py
timer
pyworld
jieba
phkit
yq yq
ConfigArgParse

@ -11,20 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import inspect
import io import io
import os import os
import re import re
import subprocess as sp
import sys import sys
from pathlib import Path from pathlib import Path
import contextlib
import inspect
from setuptools import Command
from setuptools import find_packages from setuptools import find_packages
from setuptools import setup from setuptools import setup
from setuptools import Command
from setuptools.command.develop import develop from setuptools.command.develop import develop
from setuptools.command.install import install from setuptools.command.install import install
import subprocess as sp
HERE = Path(os.path.abspath(os.path.dirname(__file__))) HERE = Path(os.path.abspath(os.path.dirname(__file__)))
@ -40,16 +40,18 @@ def pushd(new_dir):
def read(*names, **kwargs): def read(*names, **kwargs):
with io.open(os.path.join(os.path.dirname(__file__), *names), with io.open(
encoding=kwargs.get("encoding", "utf8")) as fp: os.path.join(os.path.dirname(__file__), *names),
encoding=kwargs.get("encoding", "utf8")) as fp:
return fp.read() return fp.read()
def check_call(cmd: str, shell=False, executable=None): def check_call(cmd: str, shell=False, executable=None):
try: try:
sp.check_call(cmd.split(), sp.check_call(
shell=shell, cmd.split(),
executable="/bin/bash" if shell else executable) shell=shell,
executable="/bin/bash" if shell else executable)
except sp.CalledProcessError as e: except sp.CalledProcessError as e:
print( print(
f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:", f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:",
@ -189,7 +191,6 @@ setup_info = dict(
'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.8',
], ], )
)
setup(**setup_info) setup(**setup_info)

Loading…
Cancel
Save