commit
2e9d9dc9a7
@ -0,0 +1,17 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .batch_beam_search import BatchBeamSearch
|
||||
from .beam_search import beam_search
|
||||
from .beam_search import BeamSearch
|
||||
from .beam_search import Hypothesis
|
@ -0,0 +1,17 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class BatchBeamSearch():
|
||||
pass
|
@ -0,0 +1,530 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Beam search module."""
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import NamedTuple
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import paddle
|
||||
|
||||
from ..scorers.scorer_interface import PartialScorerInterface
|
||||
from ..scorers.scorer_interface import ScorerInterface
|
||||
from ..utils import end_detect
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
class Hypothesis(NamedTuple):
|
||||
"""Hypothesis data type."""
|
||||
|
||||
yseq: paddle.Tensor # (T,)
|
||||
score: Union[float, paddle.Tensor] = 0
|
||||
scores: Dict[str, Union[float, paddle.Tensor]] = dict()
|
||||
states: Dict[str, Any] = dict()
|
||||
|
||||
def asdict(self) -> dict:
|
||||
"""Convert data to JSON-friendly dict."""
|
||||
return self._replace(
|
||||
yseq=self.yseq.tolist(),
|
||||
score=float(self.score),
|
||||
scores={k: float(v)
|
||||
for k, v in self.scores.items()}, )._asdict()
|
||||
|
||||
|
||||
class BeamSearch(paddle.nn.Layer):
|
||||
"""Beam search implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scorers: Dict[str, ScorerInterface],
|
||||
weights: Dict[str, float],
|
||||
beam_size: int,
|
||||
vocab_size: int,
|
||||
sos: int,
|
||||
eos: int,
|
||||
token_list: List[str]=None,
|
||||
pre_beam_ratio: float=1.5,
|
||||
pre_beam_score_key: str=None, ):
|
||||
"""Initialize beam search.
|
||||
|
||||
Args:
|
||||
scorers (dict[str, ScorerInterface]): Dict of decoder modules
|
||||
e.g., Decoder, CTCPrefixScorer, LM
|
||||
The scorer will be ignored if it is `None`
|
||||
weights (dict[str, float]): Dict of weights for each scorers
|
||||
The scorer will be ignored if its weight is 0
|
||||
beam_size (int): The number of hypotheses kept during search
|
||||
vocab_size (int): The number of vocabulary
|
||||
sos (int): Start of sequence id
|
||||
eos (int): End of sequence id
|
||||
token_list (list[str]): List of tokens for debug log
|
||||
pre_beam_score_key (str): key of scores to perform pre-beam search
|
||||
pre_beam_ratio (float): beam size in the pre-beam search
|
||||
will be `int(pre_beam_ratio * beam_size)`
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
# set scorers
|
||||
self.weights = weights
|
||||
self.scorers = dict() # all = full + partial
|
||||
self.full_scorers = dict() # full tokens
|
||||
self.part_scorers = dict() # partial tokens
|
||||
# this module dict is required for recursive cast
|
||||
# `self.to(device, dtype)` in `recog.py`
|
||||
self.nn_dict = paddle.nn.LayerDict() # nn.Layer
|
||||
for k, v in scorers.items():
|
||||
w = weights.get(k, 0)
|
||||
if w == 0 or v is None:
|
||||
continue
|
||||
assert isinstance(
|
||||
v, ScorerInterface
|
||||
), f"{k} ({type(v)}) does not implement ScorerInterface"
|
||||
self.scorers[k] = v
|
||||
if isinstance(v, PartialScorerInterface):
|
||||
self.part_scorers[k] = v
|
||||
else:
|
||||
self.full_scorers[k] = v
|
||||
if isinstance(v, paddle.nn.Layer):
|
||||
self.nn_dict[k] = v
|
||||
|
||||
# set configurations
|
||||
self.sos = sos
|
||||
self.eos = eos
|
||||
self.token_list = token_list
|
||||
# pre_beam_size > beam_size
|
||||
self.pre_beam_size = int(pre_beam_ratio * beam_size)
|
||||
self.beam_size = beam_size
|
||||
self.n_vocab = vocab_size
|
||||
if (pre_beam_score_key is not None and pre_beam_score_key != "full" and
|
||||
pre_beam_score_key not in self.full_scorers):
|
||||
raise KeyError(
|
||||
f"{pre_beam_score_key} is not found in {self.full_scorers}")
|
||||
# selected `key` scorer to do pre beam search
|
||||
self.pre_beam_score_key = pre_beam_score_key
|
||||
# do_pre_beam when need, valid and has part_scorers
|
||||
self.do_pre_beam = (self.pre_beam_score_key is not None and
|
||||
self.pre_beam_size < self.n_vocab and
|
||||
len(self.part_scorers) > 0)
|
||||
|
||||
def init_hyp(self, x: paddle.Tensor) -> List[Hypothesis]:
|
||||
"""Get an initial hypothesis data.
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): The encoder output feature, (T, D)
|
||||
|
||||
Returns:
|
||||
Hypothesis: The initial hypothesis.
|
||||
|
||||
"""
|
||||
init_states = dict()
|
||||
init_scores = dict()
|
||||
for k, d in self.scorers.items():
|
||||
init_states[k] = d.init_state(x)
|
||||
init_scores[k] = 0.0
|
||||
return [
|
||||
Hypothesis(
|
||||
yseq=paddle.to_tensor([self.sos], place=x.place),
|
||||
score=0.0,
|
||||
scores=init_scores,
|
||||
states=init_states, )
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def append_token(xs: paddle.Tensor,
|
||||
x: Union[int, paddle.Tensor]) -> paddle.Tensor:
|
||||
"""Append new token to prefix tokens.
|
||||
|
||||
Args:
|
||||
xs (paddle.Tensor): The prefix token, (T,)
|
||||
x (int): The new token to append
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: (T+1,), New tensor contains: xs + [x] with xs.dtype and xs.device
|
||||
|
||||
"""
|
||||
x = paddle.to_tensor([x], dtype=xs.dtype) if isinstance(x, int) else x
|
||||
return paddle.concat((xs, x))
|
||||
|
||||
def score_full(self, hyp: Hypothesis, x: paddle.Tensor
|
||||
) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]:
|
||||
"""Score new hypothesis by `self.full_scorers`.
|
||||
|
||||
Args:
|
||||
hyp (Hypothesis): Hypothesis with prefix tokens to score
|
||||
x (paddle.Tensor): Corresponding input feature, (T, D)
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: Tuple of
|
||||
score dict of `hyp` that has string keys of `self.full_scorers`
|
||||
and tensor score values of shape: `(self.n_vocab,)`,
|
||||
and state dict that has string keys
|
||||
and state values of `self.full_scorers`
|
||||
|
||||
"""
|
||||
scores = dict()
|
||||
states = dict()
|
||||
for k, d in self.full_scorers.items():
|
||||
# scores[k] shape (self.n_vocab,)
|
||||
scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
|
||||
return scores, states
|
||||
|
||||
def score_partial(self,
|
||||
hyp: Hypothesis,
|
||||
ids: paddle.Tensor,
|
||||
x: paddle.Tensor
|
||||
) -> Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]:
|
||||
"""Score new hypothesis by `self.part_scorers`.
|
||||
|
||||
Args:
|
||||
hyp (Hypothesis): Hypothesis with prefix tokens to score
|
||||
ids (paddle.Tensor): 1D tensor of new partial tokens to score,
|
||||
len(ids) < n_vocab
|
||||
x (paddle.Tensor): Corresponding input feature, (T, D)
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, paddle.Tensor], Dict[str, Any]]: Tuple of
|
||||
score dict of `hyp` that has string keys of `self.part_scorers`
|
||||
and tensor score values of shape: `(len(ids),)`,
|
||||
and state dict that has string keys
|
||||
and state values of `self.part_scorers`
|
||||
|
||||
"""
|
||||
scores = dict()
|
||||
states = dict()
|
||||
for k, d in self.part_scorers.items():
|
||||
# scores[k] shape (len(ids),)
|
||||
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k],
|
||||
x)
|
||||
return scores, states
|
||||
|
||||
def beam(self, weighted_scores: paddle.Tensor,
|
||||
ids: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""Compute topk full token ids and partial token ids.
|
||||
|
||||
Args:
|
||||
weighted_scores (paddle.Tensor): The weighted sum scores for each tokens.
|
||||
Its shape is `(self.n_vocab,)`.
|
||||
ids (paddle.Tensor): The partial token ids(Global) to compute topk.
|
||||
|
||||
Returns:
|
||||
Tuple[paddle.Tensor, paddle.Tensor]:
|
||||
The topk full token ids and partial token ids.
|
||||
Their shapes are `(self.beam_size,)`.
|
||||
i.e. (global ids, global relative local ids).
|
||||
|
||||
"""
|
||||
# no pre beam performed, `ids` equal to `weighted_scores`
|
||||
if weighted_scores.size(0) == ids.size(0):
|
||||
top_ids = weighted_scores.topk(
|
||||
self.beam_size)[1] # index in n_vocab
|
||||
return top_ids, top_ids
|
||||
|
||||
# mask pruned in pre-beam not to select in topk
|
||||
tmp = weighted_scores[ids]
|
||||
weighted_scores[:] = -float("inf")
|
||||
weighted_scores[ids] = tmp
|
||||
# top_ids no equal to local_ids, since ids shape not same
|
||||
top_ids = weighted_scores.topk(self.beam_size)[1] # index in n_vocab
|
||||
local_ids = weighted_scores[ids].topk(
|
||||
self.beam_size)[1] # index in len(ids)
|
||||
return top_ids, local_ids
|
||||
|
||||
@staticmethod
|
||||
def merge_scores(
|
||||
prev_scores: Dict[str, float],
|
||||
next_full_scores: Dict[str, paddle.Tensor],
|
||||
full_idx: int,
|
||||
next_part_scores: Dict[str, paddle.Tensor],
|
||||
part_idx: int, ) -> Dict[str, paddle.Tensor]:
|
||||
"""Merge scores for new hypothesis.
|
||||
|
||||
Args:
|
||||
prev_scores (Dict[str, float]):
|
||||
The previous hypothesis scores by `self.scorers`
|
||||
next_full_scores (Dict[str, paddle.Tensor]): scores by `self.full_scorers`
|
||||
full_idx (int): The next token id for `next_full_scores`
|
||||
next_part_scores (Dict[str, paddle.Tensor]):
|
||||
scores of partial tokens by `self.part_scorers`
|
||||
part_idx (int): The new token id for `next_part_scores`
|
||||
|
||||
Returns:
|
||||
Dict[str, paddle.Tensor]: The new score dict.
|
||||
Its keys are names of `self.full_scorers` and `self.part_scorers`.
|
||||
Its values are scalar tensors by the scorers.
|
||||
|
||||
"""
|
||||
new_scores = dict()
|
||||
for k, v in next_full_scores.items():
|
||||
new_scores[k] = prev_scores[k] + v[full_idx]
|
||||
for k, v in next_part_scores.items():
|
||||
new_scores[k] = prev_scores[k] + v[part_idx]
|
||||
return new_scores
|
||||
|
||||
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
|
||||
"""Merge states for new hypothesis.
|
||||
|
||||
Args:
|
||||
states: states of `self.full_scorers`
|
||||
part_states: states of `self.part_scorers`
|
||||
part_idx (int): The new token id for `part_scores`
|
||||
|
||||
Returns:
|
||||
Dict[str, paddle.Tensor]: The new score dict.
|
||||
Its keys are names of `self.full_scorers` and `self.part_scorers`.
|
||||
Its values are states of the scorers.
|
||||
|
||||
"""
|
||||
new_states = dict()
|
||||
for k, v in states.items():
|
||||
new_states[k] = v
|
||||
for k, d in self.part_scorers.items():
|
||||
new_states[k] = d.select_state(part_states[k], part_idx)
|
||||
return new_states
|
||||
|
||||
def search(self, running_hyps: List[Hypothesis],
|
||||
x: paddle.Tensor) -> List[Hypothesis]:
|
||||
"""Search new tokens for running hypotheses and encoded speech x.
|
||||
|
||||
Args:
|
||||
running_hyps (List[Hypothesis]): Running hypotheses on beam
|
||||
x (paddle.Tensor): Encoded speech feature (T, D)
|
||||
|
||||
Returns:
|
||||
List[Hypotheses]: Best sorted hypotheses
|
||||
|
||||
"""
|
||||
best_hyps = []
|
||||
part_ids = paddle.arange(self.n_vocab) # no pre-beam
|
||||
for hyp in running_hyps:
|
||||
# scoring
|
||||
weighted_scores = paddle.zeros([self.n_vocab], dtype=x.dtype)
|
||||
scores, states = self.score_full(hyp, x)
|
||||
for k in self.full_scorers:
|
||||
weighted_scores += self.weights[k] * scores[k]
|
||||
# partial scoring
|
||||
if self.do_pre_beam:
|
||||
pre_beam_scores = (weighted_scores
|
||||
if self.pre_beam_score_key == "full" else
|
||||
scores[self.pre_beam_score_key])
|
||||
part_ids = paddle.topk(pre_beam_scores, self.pre_beam_size)[1]
|
||||
part_scores, part_states = self.score_partial(hyp, part_ids, x)
|
||||
for k in self.part_scorers:
|
||||
weighted_scores[part_ids] += self.weights[k] * part_scores[k]
|
||||
# add previous hyp score
|
||||
weighted_scores += hyp.score
|
||||
|
||||
# update hyps
|
||||
for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
|
||||
# `part_j` is `j` relative id in `part_scores`
|
||||
# will be (2 x beam at most)
|
||||
best_hyps.append(
|
||||
Hypothesis(
|
||||
score=weighted_scores[j],
|
||||
yseq=self.append_token(hyp.yseq, j),
|
||||
scores=self.merge_scores(hyp.scores, scores, j,
|
||||
part_scores, part_j),
|
||||
states=self.merge_states(states, part_states, part_j),
|
||||
))
|
||||
|
||||
# sort and prune 2 x beam -> beam
|
||||
best_hyps = sorted(
|
||||
best_hyps, key=lambda x: x.score,
|
||||
reverse=True)[:min(len(best_hyps), self.beam_size)]
|
||||
return best_hyps
|
||||
|
||||
def forward(self,
|
||||
x: paddle.Tensor,
|
||||
maxlenratio: float=0.0,
|
||||
minlenratio: float=0.0) -> List[Hypothesis]:
|
||||
"""Perform beam search.
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): Encoded speech feature (T, D)
|
||||
maxlenratio (float): Input length ratio to obtain max output length.
|
||||
If maxlenratio=0.0 (default), it uses a end-detect function
|
||||
to automatically find maximum hypothesis lengths
|
||||
If maxlenratio<0.0, its absolute value is interpreted
|
||||
as a constant max output length.
|
||||
minlenratio (float): Input length ratio to obtain min output length.
|
||||
|
||||
Returns:
|
||||
list[Hypothesis]: N-best decoding results
|
||||
|
||||
"""
|
||||
# set length bounds
|
||||
if maxlenratio == 0:
|
||||
maxlen = x.shape[0]
|
||||
elif maxlenratio < 0:
|
||||
maxlen = -1 * int(maxlenratio)
|
||||
else:
|
||||
maxlen = max(1, int(maxlenratio * x.size(0)))
|
||||
minlen = int(minlenratio * x.size(0))
|
||||
logger.info("decoder input length: " + str(x.shape[0]))
|
||||
logger.info("max output length: " + str(maxlen))
|
||||
logger.info("min output length: " + str(minlen))
|
||||
|
||||
# main loop of prefix search
|
||||
running_hyps = self.init_hyp(x)
|
||||
ended_hyps = []
|
||||
for i in range(maxlen):
|
||||
logger.debug("position " + str(i))
|
||||
best = self.search(running_hyps, x)
|
||||
# post process of one iteration
|
||||
running_hyps = self.post_process(i, maxlen, maxlenratio, best,
|
||||
ended_hyps)
|
||||
# end detection
|
||||
if maxlenratio == 0.0 and end_detect(
|
||||
[h.asdict() for h in ended_hyps], i):
|
||||
logger.info(f"end detected at {i}")
|
||||
break
|
||||
if len(running_hyps) == 0:
|
||||
logger.info("no hypothesis. Finish decoding.")
|
||||
break
|
||||
else:
|
||||
logger.debug(f"remained hypotheses: {len(running_hyps)}")
|
||||
|
||||
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
|
||||
# check the number of hypotheses reaching to eos
|
||||
if len(nbest_hyps) == 0:
|
||||
logger.warning("there is no N-best results, perform recognition "
|
||||
"again with smaller minlenratio.")
|
||||
return ([] if minlenratio < 0.1 else
|
||||
self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)))
|
||||
|
||||
# report the best result
|
||||
best = nbest_hyps[0]
|
||||
for k, v in best.scores.items():
|
||||
logger.info(
|
||||
f"{float(v):6.2f} * {self.weights[k]:3} = {float(v) * self.weights[k]:6.2f} for {k}"
|
||||
)
|
||||
logger.info(f"total log probability: {float(best.score):.2f}")
|
||||
logger.info(
|
||||
f"normalized log probability: {float(best.score) / len(best.yseq):.2f}"
|
||||
)
|
||||
logger.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
|
||||
if self.token_list is not None:
|
||||
# logger.info(
|
||||
# "best hypo: "
|
||||
# + "".join([self.token_list[x] for x in best.yseq[1:-1]])
|
||||
# + "\n"
|
||||
# )
|
||||
logger.info("best hypo: " + "".join(
|
||||
[self.token_list[x] for x in best.yseq[1:]]) + "\n")
|
||||
return nbest_hyps
|
||||
|
||||
def post_process(
|
||||
self,
|
||||
i: int,
|
||||
maxlen: int,
|
||||
maxlenratio: float,
|
||||
running_hyps: List[Hypothesis],
|
||||
ended_hyps: List[Hypothesis], ) -> List[Hypothesis]:
|
||||
"""Perform post-processing of beam search iterations.
|
||||
|
||||
Args:
|
||||
i (int): The length of hypothesis tokens.
|
||||
maxlen (int): The maximum length of tokens in beam search.
|
||||
maxlenratio (int): The maximum length ratio in beam search.
|
||||
running_hyps (List[Hypothesis]): The running hypotheses in beam search.
|
||||
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
|
||||
|
||||
Returns:
|
||||
List[Hypothesis]: The new running hypotheses.
|
||||
|
||||
"""
|
||||
logger.debug(f"the number of running hypotheses: {len(running_hyps)}")
|
||||
if self.token_list is not None:
|
||||
logger.debug("best hypo: " + "".join(
|
||||
[self.token_list[x] for x in running_hyps[0].yseq[1:]]))
|
||||
# add eos in the final loop to avoid that there are no ended hyps
|
||||
if i == maxlen - 1:
|
||||
logger.info("adding <eos> in the last position in the loop")
|
||||
running_hyps = [
|
||||
h._replace(yseq=self.append_token(h.yseq, self.eos))
|
||||
for h in running_hyps
|
||||
]
|
||||
|
||||
# add ended hypotheses to a final list, and removed them from current hypotheses
|
||||
# (this will be a problem, number of hyps < beam)
|
||||
remained_hyps = []
|
||||
for hyp in running_hyps:
|
||||
if hyp.yseq[-1] == self.eos:
|
||||
# e.g., Word LM needs to add final <eos> score
|
||||
for k, d in chain(self.full_scorers.items(),
|
||||
self.part_scorers.items()):
|
||||
s = d.final_score(hyp.states[k])
|
||||
hyp.scores[k] += s
|
||||
hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
|
||||
ended_hyps.append(hyp)
|
||||
else:
|
||||
remained_hyps.append(hyp)
|
||||
return remained_hyps
|
||||
|
||||
|
||||
def beam_search(
|
||||
x: paddle.Tensor,
|
||||
sos: int,
|
||||
eos: int,
|
||||
beam_size: int,
|
||||
vocab_size: int,
|
||||
scorers: Dict[str, ScorerInterface],
|
||||
weights: Dict[str, float],
|
||||
token_list: List[str]=None,
|
||||
maxlenratio: float=0.0,
|
||||
minlenratio: float=0.0,
|
||||
pre_beam_ratio: float=1.5,
|
||||
pre_beam_score_key: str="full", ) -> list:
|
||||
"""Perform beam search with scorers.
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): Encoded speech feature (T, D)
|
||||
sos (int): Start of sequence id
|
||||
eos (int): End of sequence id
|
||||
beam_size (int): The number of hypotheses kept during search
|
||||
vocab_size (int): The number of vocabulary
|
||||
scorers (dict[str, ScorerInterface]): Dict of decoder modules
|
||||
e.g., Decoder, CTCPrefixScorer, LM
|
||||
The scorer will be ignored if it is `None`
|
||||
weights (dict[str, float]): Dict of weights for each scorers
|
||||
The scorer will be ignored if its weight is 0
|
||||
token_list (list[str]): List of tokens for debug log
|
||||
maxlenratio (float): Input length ratio to obtain max output length.
|
||||
If maxlenratio=0.0 (default), it uses a end-detect function
|
||||
to automatically find maximum hypothesis lengths
|
||||
minlenratio (float): Input length ratio to obtain min output length.
|
||||
pre_beam_score_key (str): key of scores to perform pre-beam search
|
||||
pre_beam_ratio (float): beam size in the pre-beam search
|
||||
will be `int(pre_beam_ratio * beam_size)`
|
||||
|
||||
Returns:
|
||||
List[Dict]: N-best decoding results
|
||||
|
||||
"""
|
||||
ret = BeamSearch(
|
||||
scorers,
|
||||
weights,
|
||||
beam_size=beam_size,
|
||||
vocab_size=vocab_size,
|
||||
pre_beam_ratio=pre_beam_ratio,
|
||||
pre_beam_score_key=pre_beam_score_key,
|
||||
sos=sos,
|
||||
eos=eos,
|
||||
token_list=token_list, ).forward(
|
||||
x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
|
||||
return [h.asdict() for h in ret]
|
@ -0,0 +1,188 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
|
||||
import jsonlines
|
||||
import paddle
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from .beam_search import BatchBeamSearch
|
||||
from .beam_search import BeamSearch
|
||||
from .scorers.length_bonus import LengthBonus
|
||||
from .scorers.scorer_interface import BatchScorerInterface
|
||||
from .utils import add_results_to_json
|
||||
from deepspeech.exps import dynamic_import_tester
|
||||
from deepspeech.io.reader import LoadInputsAndTargets
|
||||
from deepspeech.models.asr_interface import ASRInterface
|
||||
from deepspeech.utils.log import Log
|
||||
# from espnet.asr.asr_utils import get_model_conf
|
||||
# from espnet.asr.asr_utils import torch_load
|
||||
# from espnet.nets.lm_interface import dynamic_import_lm
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
# NOTE: you need this func to generate our sphinx doc
|
||||
|
||||
|
||||
def load_trained_model(args):
|
||||
args.nprocs = args.ngpu
|
||||
confs = CfgNode()
|
||||
confs.set_new_allowed(True)
|
||||
confs.merge_from_file(args.model_conf)
|
||||
class_obj = dynamic_import_tester(args.model_name)
|
||||
exp = class_obj(confs, args)
|
||||
with exp.eval():
|
||||
exp.setup()
|
||||
exp.restore()
|
||||
char_list = exp.args.char_list
|
||||
model = exp.model
|
||||
return model, char_list, exp, confs
|
||||
|
||||
|
||||
def recog_v2(args):
|
||||
"""Decode with custom models that implements ScorerInterface.
|
||||
|
||||
Args:
|
||||
args (namespace): The program arguments.
|
||||
See py:func:`bin.asr_recog.get_parser` for details
|
||||
|
||||
"""
|
||||
logger.warning("experimental API for custom LMs is selected by --api v2")
|
||||
if args.batchsize > 1:
|
||||
raise NotImplementedError("multi-utt batch decoding is not implemented")
|
||||
if args.streaming_mode is not None:
|
||||
raise NotImplementedError("streaming mode is not implemented")
|
||||
if args.word_rnnlm:
|
||||
raise NotImplementedError("word LM is not implemented")
|
||||
|
||||
# set_deterministic(args)
|
||||
model, char_list, exp, confs = load_trained_model(args)
|
||||
assert isinstance(model, ASRInterface)
|
||||
|
||||
load_inputs_and_targets = LoadInputsAndTargets(
|
||||
mode="asr",
|
||||
load_output=False,
|
||||
sort_in_input_length=False,
|
||||
preprocess_conf=confs.collator.augmentation_config
|
||||
if args.preprocess_conf is None else args.preprocess_conf,
|
||||
preprocess_args={"train": False}, )
|
||||
|
||||
if args.rnnlm:
|
||||
lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
|
||||
# NOTE: for a compatibility with less than 0.5.0 version models
|
||||
lm_model_module = getattr(lm_args, "model_module", "default")
|
||||
lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
|
||||
lm = lm_class(len(char_list), lm_args)
|
||||
torch_load(args.rnnlm, lm)
|
||||
lm.eval()
|
||||
else:
|
||||
lm = None
|
||||
|
||||
if args.ngram_model:
|
||||
from .scorers.ngram import NgramFullScorer
|
||||
from .scorers.ngram import NgramPartScorer
|
||||
|
||||
if args.ngram_scorer == "full":
|
||||
ngram = NgramFullScorer(args.ngram_model, char_list)
|
||||
else:
|
||||
ngram = NgramPartScorer(args.ngram_model, char_list)
|
||||
else:
|
||||
ngram = None
|
||||
|
||||
scorers = model.scorers() # decoder
|
||||
scorers["lm"] = lm
|
||||
scorers["ngram"] = ngram
|
||||
scorers["length_bonus"] = LengthBonus(len(char_list))
|
||||
weights = dict(
|
||||
decoder=1.0 - args.ctc_weight,
|
||||
ctc=args.ctc_weight,
|
||||
lm=args.lm_weight,
|
||||
ngram=args.ngram_weight,
|
||||
length_bonus=args.penalty, )
|
||||
beam_search = BeamSearch(
|
||||
beam_size=args.beam_size,
|
||||
vocab_size=len(char_list),
|
||||
weights=weights,
|
||||
scorers=scorers,
|
||||
sos=model.sos,
|
||||
eos=model.eos,
|
||||
token_list=char_list,
|
||||
pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", )
|
||||
|
||||
# TODO(karita): make all scorers batchfied
|
||||
if args.batchsize == 1:
|
||||
non_batch = [
|
||||
k for k, v in beam_search.full_scorers.items()
|
||||
if not isinstance(v, BatchScorerInterface)
|
||||
]
|
||||
if len(non_batch) == 0:
|
||||
beam_search.__class__ = BatchBeamSearch
|
||||
logger.info("BatchBeamSearch implementation is selected.")
|
||||
else:
|
||||
logger.warning(f"As non-batch scorers {non_batch} are found, "
|
||||
f"fall back to non-batch implementation.")
|
||||
|
||||
if args.ngpu > 1:
|
||||
raise NotImplementedError("only single GPU decoding is supported")
|
||||
if args.ngpu == 1:
|
||||
device = "gpu:0"
|
||||
else:
|
||||
device = "cpu"
|
||||
paddle.set_device(device)
|
||||
dtype = getattr(paddle, args.dtype)
|
||||
logger.info(f"Decoding device={device}, dtype={dtype}")
|
||||
model.to(device=device, dtype=dtype)
|
||||
model.eval()
|
||||
beam_search.to(device=device, dtype=dtype)
|
||||
beam_search.eval()
|
||||
|
||||
# read json data
|
||||
js = []
|
||||
with jsonlines.open(args.recog_json, "r") as reader:
|
||||
for item in reader:
|
||||
js.append(item)
|
||||
# jsonlines to dict, key by 'utt', value by jsonline
|
||||
js = {item['utt']: item for item in js}
|
||||
|
||||
new_js = {}
|
||||
with paddle.no_grad():
|
||||
with jsonlines.open(args.result_label, "w") as f:
|
||||
for idx, name in enumerate(js.keys(), 1):
|
||||
logger.info(f"({idx}/{len(js.keys())}) decoding " + name)
|
||||
batch = [(name, js[name])]
|
||||
feat = load_inputs_and_targets(batch)[0][0]
|
||||
logger.info(f'feat: {feat.shape}')
|
||||
enc = model.encode(paddle.to_tensor(feat).to(dtype))
|
||||
logger.info(f'eout: {enc.shape}')
|
||||
nbest_hyps = beam_search(
|
||||
x=enc,
|
||||
maxlenratio=args.maxlenratio,
|
||||
minlenratio=args.minlenratio)
|
||||
nbest_hyps = [
|
||||
h.asdict()
|
||||
for h in nbest_hyps[:min(len(nbest_hyps), args.nbest)]
|
||||
]
|
||||
new_js[name] = add_results_to_json(js[name], nbest_hyps,
|
||||
char_list)
|
||||
|
||||
item = new_js[name]['output'][0] # 1-best
|
||||
ref = item['text']
|
||||
rec_text = item['rec_text'].replace('▁', ' ').replace(
|
||||
'<eos>', '').strip()
|
||||
rec_tokenid = list(map(int, item['rec_tokenid'].split()))
|
||||
f.write({
|
||||
"utt": name,
|
||||
"refs": [ref],
|
||||
"hyps": [rec_text],
|
||||
"hyps_tokenid": [rec_tokenid],
|
||||
})
|
@ -0,0 +1,376 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""End-to-end speech recognition model decoding script."""
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from distutils.util import strtobool
|
||||
|
||||
import configargparse
|
||||
import numpy as np
|
||||
|
||||
from deepspeech.decoders.recog import recog_v2
|
||||
|
||||
|
||||
def get_parser():
|
||||
"""Get default arguments."""
|
||||
parser = configargparse.ArgumentParser(
|
||||
description="Transcribe text from speech using "
|
||||
"a speech recognition model on one CPU or GPU",
|
||||
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
||||
formatter_class=configargparse.ArgumentDefaultsHelpFormatter, )
|
||||
parser.add(
|
||||
'--model-name',
|
||||
type=str,
|
||||
default='u2_kaldi',
|
||||
help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
|
||||
# general configuration
|
||||
parser.add("--config", is_config_file=True, help="Config file path")
|
||||
parser.add(
|
||||
"--config2",
|
||||
is_config_file=True,
|
||||
help="Second config file path that overwrites the settings in `--config`",
|
||||
)
|
||||
parser.add(
|
||||
"--config3",
|
||||
is_config_file=True,
|
||||
help="Third config file path that overwrites the settings "
|
||||
"in `--config` and `--config2`", )
|
||||
|
||||
parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
choices=("float16", "float32", "float64"),
|
||||
default="float32",
|
||||
help="Float precision (only available in --api v2)", )
|
||||
parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
|
||||
parser.add_argument("--seed", type=int, default=1, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--verbose", "-V", type=int, default=2, help="Verbose option")
|
||||
parser.add_argument(
|
||||
"--batchsize",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size for beam search (0: means no batch processing)", )
|
||||
parser.add_argument(
|
||||
"--preprocess-conf",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The configuration file for the pre-processing", )
|
||||
parser.add_argument(
|
||||
"--api",
|
||||
default="v2",
|
||||
choices=["v2"],
|
||||
help="Beam search APIs "
|
||||
"v2: Experimental API. It supports any models that implements ScorerInterface.",
|
||||
)
|
||||
# task related
|
||||
parser.add_argument(
|
||||
"--recog-json", type=str, help="Filename of recognition data (json)")
|
||||
parser.add_argument(
|
||||
"--result-label",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Filename of result label data (json)", )
|
||||
# model (parameter) related
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model file parameters to read")
|
||||
parser.add_argument(
|
||||
"--model-conf", type=str, default=None, help="Model config file")
|
||||
parser.add_argument(
|
||||
"--num-spkrs",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[1, 2],
|
||||
help="Number of speakers in the speech", )
|
||||
parser.add_argument(
|
||||
"--num-encs",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Number of encoders in the model.")
|
||||
# search related
|
||||
parser.add_argument(
|
||||
"--nbest", type=int, default=1, help="Output N-best hypotheses")
|
||||
parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
|
||||
parser.add_argument(
|
||||
"--penalty", type=float, default=0.0, help="Incertion penalty")
|
||||
parser.add_argument(
|
||||
"--maxlenratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""Input length ratio to obtain max output length.
|
||||
If maxlenratio=0.0 (default), it uses a end-detect function
|
||||
to automatically find maximum hypothesis lengths.
|
||||
If maxlenratio<0.0, its absolute value is interpreted
|
||||
as a constant max output length""", )
|
||||
parser.add_argument(
|
||||
"--minlenratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Input length ratio to obtain min output length", )
|
||||
parser.add_argument(
|
||||
"--ctc-weight",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="CTC weight in joint decoding")
|
||||
parser.add_argument(
|
||||
"--weights-ctc-dec",
|
||||
type=float,
|
||||
action="append",
|
||||
help="ctc weight assigned to each encoder during decoding."
|
||||
"[in multi-encoder mode only]", )
|
||||
parser.add_argument(
|
||||
"--ctc-window-margin",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Use CTC window with margin parameter to accelerate
|
||||
CTC/attention decoding especially on GPU. Smaller magin
|
||||
makes decoding faster, but may increase search errors.
|
||||
If margin=0 (default), this function is disabled""", )
|
||||
# transducer related
|
||||
parser.add_argument(
|
||||
"--search-type",
|
||||
type=str,
|
||||
default="default",
|
||||
choices=["default", "nsc", "tsd", "alsd", "maes"],
|
||||
help="""Type of beam search implementation to use during inference.
|
||||
Can be either: default beam search ("default"),
|
||||
N-Step Constrained beam search ("nsc"), Time-Synchronous Decoding ("tsd"),
|
||||
Alignment-Length Synchronous Decoding ("alsd") or
|
||||
modified Adaptive Expansion Search ("maes").""", )
|
||||
parser.add_argument(
|
||||
"--nstep",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Number of expansion steps allowed in NSC beam search or mAES
|
||||
(nstep > 0 for NSC and nstep > 1 for mAES).""", )
|
||||
parser.add_argument(
|
||||
"--prefix-alpha",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Length prefix difference allowed in NSC beam search or mAES.", )
|
||||
parser.add_argument(
|
||||
"--max-sym-exp",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of symbol expansions allowed in TSD.", )
|
||||
parser.add_argument(
|
||||
"--u-max",
|
||||
type=int,
|
||||
default=400,
|
||||
help="Length prefix difference allowed in ALSD.", )
|
||||
parser.add_argument(
|
||||
"--expansion-gamma",
|
||||
type=float,
|
||||
default=2.3,
|
||||
help="Allowed logp difference for prune-by-value method in mAES.", )
|
||||
parser.add_argument(
|
||||
"--expansion-beta",
|
||||
type=int,
|
||||
default=2,
|
||||
help="""Number of additional candidates for expanded hypotheses
|
||||
selection in mAES.""", )
|
||||
parser.add_argument(
|
||||
"--score-norm",
|
||||
type=strtobool,
|
||||
nargs="?",
|
||||
default=True,
|
||||
help="Normalize final hypotheses' score by length", )
|
||||
parser.add_argument(
|
||||
"--softmax-temperature",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Penalization term for softmax function.", )
|
||||
# rnnlm related
|
||||
parser.add_argument(
|
||||
"--rnnlm", type=str, default=None, help="RNNLM model file to read")
|
||||
parser.add_argument(
|
||||
"--rnnlm-conf",
|
||||
type=str,
|
||||
default=None,
|
||||
help="RNNLM model config file to read")
|
||||
parser.add_argument(
|
||||
"--word-rnnlm",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Word RNNLM model file to read")
|
||||
parser.add_argument(
|
||||
"--word-rnnlm-conf",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Word RNNLM model config file to read", )
|
||||
parser.add_argument(
|
||||
"--word-dict", type=str, default=None, help="Word list to read")
|
||||
parser.add_argument(
|
||||
"--lm-weight", type=float, default=0.1, help="RNNLM weight")
|
||||
# ngram related
|
||||
parser.add_argument(
|
||||
"--ngram-model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="ngram model file to read")
|
||||
parser.add_argument(
|
||||
"--ngram-weight", type=float, default=0.1, help="ngram weight")
|
||||
parser.add_argument(
|
||||
"--ngram-scorer",
|
||||
type=str,
|
||||
default="part",
|
||||
choices=("full", "part"),
|
||||
help="""if the ngram is set as a part scorer, similar with CTC scorer,
|
||||
ngram scorer only scores topK hypethesis.
|
||||
if the ngram is set as full scorer, ngram scorer scores all hypthesis
|
||||
the decoding speed of part scorer is musch faster than full one""",
|
||||
)
|
||||
# streaming related
|
||||
parser.add_argument(
|
||||
"--streaming-mode",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["window", "segment"],
|
||||
help="""Use streaming recognizer for inference.
|
||||
`--batchsize` must be set to 0 to enable this mode""", )
|
||||
parser.add_argument(
|
||||
"--streaming-window", type=int, default=10, help="Window size")
|
||||
parser.add_argument(
|
||||
"--streaming-min-blank-dur",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Minimum blank duration threshold", )
|
||||
parser.add_argument(
|
||||
"--streaming-onset-margin", type=int, default=1, help="Onset margin")
|
||||
parser.add_argument(
|
||||
"--streaming-offset-margin", type=int, default=1, help="Offset margin")
|
||||
# non-autoregressive related
|
||||
# Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail.
|
||||
parser.add_argument(
|
||||
"--maskctc-n-iterations",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of decoding iterations."
|
||||
"For Mask CTC, set 0 to predict 1 mask/iter.", )
|
||||
parser.add_argument(
|
||||
"--maskctc-probability-threshold",
|
||||
type=float,
|
||||
default=0.999,
|
||||
help="Threshold probability for CTC output", )
|
||||
# quantize model related
|
||||
parser.add_argument(
|
||||
"--quantize-config",
|
||||
nargs="*",
|
||||
help="Quantize config list. E.g.: --quantize-config=[Linear,LSTM,GRU]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantize-dtype",
|
||||
type=str,
|
||||
default="qint8",
|
||||
help="Dtype dynamic quantize")
|
||||
parser.add_argument(
|
||||
"--quantize-asr-model",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Quantize asr model", )
|
||||
parser.add_argument(
|
||||
"--quantize-lm-model",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Quantize lm model", )
|
||||
return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
"""Run the main decoding function."""
|
||||
parser = get_parser()
|
||||
parser.add_argument(
|
||||
"--output", metavar="CKPT_DIR", help="path to save checkpoint.")
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, help="path to load checkpoint")
|
||||
parser.add_argument("--dict-path", type=str, help="path to load checkpoint")
|
||||
args = parser.parse_args(args)
|
||||
|
||||
if args.ngpu == 0 and args.dtype == "float16":
|
||||
raise ValueError(
|
||||
f"--dtype {args.dtype} does not support the CPU backend.")
|
||||
|
||||
# logging info
|
||||
if args.verbose == 1:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
elif args.verbose == 2:
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
else:
|
||||
logging.basicConfig(
|
||||
level=logging.WARN,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
logging.warning("Skip DEBUG/INFO messages")
|
||||
logging.info(args)
|
||||
|
||||
# check CUDA_VISIBLE_DEVICES
|
||||
if args.ngpu > 0:
|
||||
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
|
||||
if cvd is None:
|
||||
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
|
||||
elif args.ngpu != len(cvd.split(",")):
|
||||
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
|
||||
sys.exit(1)
|
||||
|
||||
# TODO(mn5k): support of multiple GPUs
|
||||
if args.ngpu > 1:
|
||||
logging.error("The program only supports ngpu=1.")
|
||||
sys.exit(1)
|
||||
|
||||
# display PYTHONPATH
|
||||
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
|
||||
|
||||
# seed setting
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
logging.info("set random seed = %d" % args.seed)
|
||||
|
||||
# validate rnn options
|
||||
if args.rnnlm is not None and args.word_rnnlm is not None:
|
||||
logging.error(
|
||||
"It seems that both --rnnlm and --word-rnnlm are specified. "
|
||||
"Please use either option.")
|
||||
sys.exit(1)
|
||||
|
||||
# recog
|
||||
if args.num_spkrs == 1:
|
||||
if args.num_encs == 1:
|
||||
# Experimental API that supports custom LMs
|
||||
if args.api == "v2":
|
||||
|
||||
recog_v2(args)
|
||||
else:
|
||||
raise ValueError("Only support --api v2")
|
||||
else:
|
||||
if args.api == "v2":
|
||||
raise NotImplementedError(
|
||||
f"--num-encs {args.num_encs} > 1 is not supported in --api v2"
|
||||
)
|
||||
elif args.num_spkrs == 2:
|
||||
raise ValueError("asr_mix not supported.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1:])
|
@ -0,0 +1,187 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluation for U2 model."""
|
||||
import cProfile
|
||||
import os
|
||||
import sys
|
||||
|
||||
import paddle
|
||||
import soundfile
|
||||
|
||||
from deepspeech.exps.u2.config import get_cfg_defaults
|
||||
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||
from deepspeech.io.collator import SpeechCollator
|
||||
from deepspeech.models.u2 import U2Model
|
||||
from deepspeech.training.cli import default_argument_parser
|
||||
from deepspeech.training.trainer import Trainer
|
||||
from deepspeech.utils import layer_tools
|
||||
from deepspeech.utils import mp_tools
|
||||
from deepspeech.utils.log import Log
|
||||
from deepspeech.utils.utility import print_arguments
|
||||
from deepspeech.utils.utility import UpdateConfig
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
# TODO(hui zhang): dynamic load
|
||||
|
||||
|
||||
class U2Tester_Hub(Trainer):
|
||||
def __init__(self, config, args):
|
||||
# super().__init__(config, args)
|
||||
self.args = args
|
||||
self.config = config
|
||||
self.audio_file = args.audio_file
|
||||
self.collate_fn_test = SpeechCollator.from_config(config)
|
||||
self._text_featurizer = TextFeaturizer(
|
||||
unit_type=config.collator.unit_type,
|
||||
vocab_filepath=None,
|
||||
spm_model_prefix=config.collator.spm_model_prefix)
|
||||
|
||||
def setup_model(self):
|
||||
config = self.config
|
||||
model_conf = config.model
|
||||
|
||||
with UpdateConfig(model_conf):
|
||||
model_conf.input_dim = self.collate_fn_test.feature_size
|
||||
model_conf.output_dim = self.collate_fn_test.vocab_size
|
||||
|
||||
model = U2Model.from_config(model_conf)
|
||||
|
||||
if self.parallel:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
logger.info(f"{model}")
|
||||
layer_tools.print_params(model, logger.info)
|
||||
|
||||
self.model = model
|
||||
logger.info("Setup model")
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
@paddle.no_grad()
|
||||
def test(self):
|
||||
self.model.eval()
|
||||
cfg = self.config.decoding
|
||||
audio_file = self.audio_file
|
||||
collate_fn_test = self.collate_fn_test
|
||||
audio, _ = collate_fn_test.process_utterance(
|
||||
audio_file=audio_file, transcript="Hello")
|
||||
audio_len = audio.shape[0]
|
||||
audio = paddle.to_tensor(audio, dtype='float32')
|
||||
audio_len = paddle.to_tensor(audio_len)
|
||||
audio = paddle.unsqueeze(audio, axis=0)
|
||||
vocab_list = collate_fn_test.vocab_list
|
||||
|
||||
text_feature = self.collate_fn_test.text_feature
|
||||
result_transcripts = self.model.decode(
|
||||
audio,
|
||||
audio_len,
|
||||
text_feature=text_feature,
|
||||
decoding_method=cfg.decoding_method,
|
||||
lang_model_path=cfg.lang_model_path,
|
||||
beam_alpha=cfg.alpha,
|
||||
beam_beta=cfg.beta,
|
||||
beam_size=cfg.beam_size,
|
||||
cutoff_prob=cfg.cutoff_prob,
|
||||
cutoff_top_n=cfg.cutoff_top_n,
|
||||
num_processes=cfg.num_proc_bsearch,
|
||||
ctc_weight=cfg.ctc_weight,
|
||||
decoding_chunk_size=cfg.decoding_chunk_size,
|
||||
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
|
||||
simulate_streaming=cfg.simulate_streaming)
|
||||
logger.info("The result_transcripts: " + result_transcripts[0][0])
|
||||
|
||||
def run_test(self):
|
||||
self.resume()
|
||||
try:
|
||||
self.test()
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(-1)
|
||||
|
||||
def setup(self):
|
||||
"""Setup the experiment.
|
||||
"""
|
||||
paddle.set_device('gpu' if self.args.nprocs > 0 else 'cpu')
|
||||
|
||||
#self.setup_output_dir()
|
||||
#self.setup_checkpointer()
|
||||
|
||||
#self.setup_dataloader()
|
||||
self.setup_model()
|
||||
|
||||
self.iteration = 0
|
||||
self.epoch = 0
|
||||
|
||||
def resume(self):
|
||||
"""Resume from the checkpoint at checkpoints in the output
|
||||
directory or load a specified checkpoint.
|
||||
"""
|
||||
params_path = self.args.checkpoint_path + ".pdparams"
|
||||
model_dict = paddle.load(params_path)
|
||||
self.model.set_state_dict(model_dict)
|
||||
|
||||
|
||||
def check(audio_file):
|
||||
logger.info("checking the audio file format......")
|
||||
try:
|
||||
sig, sample_rate = soundfile.read(audio_file)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
logger.error(
|
||||
"can not open the wav file, please check the audio file format")
|
||||
sys.exit(-1)
|
||||
logger.info("The sample rate is %d" % sample_rate)
|
||||
assert (sample_rate == 16000)
|
||||
logger.info("The audio file format is right")
|
||||
|
||||
|
||||
def main_sp(config, args):
|
||||
exp = U2Tester_Hub(config, args)
|
||||
with exp.eval():
|
||||
exp.setup()
|
||||
exp.run_test()
|
||||
|
||||
|
||||
def main(config, args):
|
||||
main_sp(config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = default_argument_parser()
|
||||
# save asr result to
|
||||
parser.add_argument(
|
||||
"--result_file", type=str, help="path of save the asr result")
|
||||
parser.add_argument(
|
||||
"--audio_file", type=str, help="path of the input audio file")
|
||||
args = parser.parse_args()
|
||||
print_arguments(args, globals())
|
||||
|
||||
if not os.path.isfile(args.audio_file):
|
||||
print("Please input the right audio file path")
|
||||
sys.exit(-1)
|
||||
check(args.audio_file)
|
||||
# https://yaml.org/type/float.html
|
||||
config = get_cfg_defaults()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
if args.opts:
|
||||
config.merge_from_list(args.opts)
|
||||
config.freeze()
|
||||
print(config)
|
||||
if args.dump_config:
|
||||
with open(args.dump_config, 'w') as f:
|
||||
print(config, file=f)
|
||||
|
||||
# Setting for profiling
|
||||
pr = cProfile.Profile()
|
||||
pr.runcall(main, config, args)
|
||||
pr.dump_stats('test.profile')
|
@ -0,0 +1,19 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
from deepspeech.decoders.recog_bin import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1:])
|
@ -0,0 +1,161 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""ASR Interface module."""
|
||||
import argparse
|
||||
|
||||
from deepspeech.utils.dynamic_import import dynamic_import
|
||||
|
||||
|
||||
class ASRInterface:
|
||||
"""ASR Interface for ESPnet model implementation."""
|
||||
|
||||
@staticmethod
|
||||
def add_arguments(parser):
|
||||
"""Add arguments to parser."""
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def build(cls, idim: int, odim: int, **kwargs):
|
||||
"""Initialize this class with python-level args.
|
||||
|
||||
Args:
|
||||
idim (int): The number of an input feature dim.
|
||||
odim (int): The number of output vocab.
|
||||
|
||||
Returns:
|
||||
ASRinterface: A new instance of ASRInterface.
|
||||
|
||||
"""
|
||||
args = argparse.Namespace(**kwargs)
|
||||
return cls(idim, odim, args)
|
||||
|
||||
def forward(self, xs, ilens, ys, olens):
|
||||
"""Compute loss for training.
|
||||
|
||||
:param xs: batch of padded source sequences paddle.Tensor (B, Tmax, idim)
|
||||
:param ilens: batch of lengths of source sequences (B), paddle.Tensor
|
||||
:param ys: batch of padded target sequences paddle.Tensor (B, Lmax)
|
||||
:param olens: batch of lengths of target sequences (B), paddle.Tensor
|
||||
:return: loss value
|
||||
:rtype: paddle.Tensor
|
||||
"""
|
||||
raise NotImplementedError("forward method is not implemented")
|
||||
|
||||
def recognize(self, x, recog_args, char_list=None, rnnlm=None):
|
||||
"""Recognize x for evaluation.
|
||||
|
||||
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
|
||||
:param namespace recog_args: argment namespace contraining options
|
||||
:param list char_list: list of characters
|
||||
:param paddle.nn.Layer rnnlm: language model module
|
||||
:return: N-best decoding results
|
||||
:rtype: list
|
||||
"""
|
||||
raise NotImplementedError("recognize method is not implemented")
|
||||
|
||||
def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
|
||||
"""Beam search implementation for batch.
|
||||
|
||||
:param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
|
||||
:param namespace recog_args: argument namespace containing options
|
||||
:param list char_list: list of characters
|
||||
:param paddle.nn.Module rnnlm: language model module
|
||||
:return: N-best decoding results
|
||||
:rtype: list
|
||||
"""
|
||||
raise NotImplementedError("Batch decoding is not supported yet.")
|
||||
|
||||
def calculate_all_attentions(self, xs, ilens, ys):
|
||||
"""Calculate attention.
|
||||
|
||||
:param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
|
||||
:param ndarray ilens: batch of lengths of input sequences (B)
|
||||
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
|
||||
:return: attention weights (B, Lmax, Tmax)
|
||||
:rtype: float ndarray
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"calculate_all_attentions method is not implemented")
|
||||
|
||||
def calculate_all_ctc_probs(self, xs, ilens, ys):
|
||||
"""Calculate CTC probability.
|
||||
|
||||
:param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
|
||||
:param ndarray ilens: batch of lengths of input sequences (B)
|
||||
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
|
||||
:return: CTC probabilities (B, Tmax, vocab)
|
||||
:rtype: float ndarray
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"calculate_all_ctc_probs method is not implemented")
|
||||
|
||||
@property
|
||||
def attention_plot_class(self):
|
||||
"""Get attention plot class."""
|
||||
from espnet.asr.asr_utils import PlotAttentionReport
|
||||
|
||||
return PlotAttentionReport
|
||||
|
||||
@property
|
||||
def ctc_plot_class(self):
|
||||
"""Get CTC plot class."""
|
||||
from espnet.asr.asr_utils import PlotCTCReport
|
||||
|
||||
return PlotCTCReport
|
||||
|
||||
def get_total_subsampling_factor(self):
|
||||
"""Get total subsampling factor."""
|
||||
raise NotImplementedError(
|
||||
"get_total_subsampling_factor method is not implemented")
|
||||
|
||||
def encode(self, feat):
|
||||
"""Encode feature in `beam_search` (optional).
|
||||
|
||||
Args:
|
||||
x (numpy.ndarray): input feature (T, D)
|
||||
Returns:
|
||||
paddle.Tensor: encoded feature (T, D)
|
||||
"""
|
||||
raise NotImplementedError("encode method is not implemented")
|
||||
|
||||
def scorers(self):
|
||||
"""Get scorers for `beam_search` (optional).
|
||||
|
||||
Returns:
|
||||
dict[str, ScorerInterface]: dict of `ScorerInterface` objects
|
||||
|
||||
"""
|
||||
raise NotImplementedError("decoders method is not implemented")
|
||||
|
||||
|
||||
predefined_asr = {
|
||||
"transformer": "deepspeech.models.u2:U2Model",
|
||||
"conformer": "deepspeech.models.u2:U2Model",
|
||||
}
|
||||
|
||||
|
||||
def dynamic_import_asr(module):
|
||||
"""Import ASR models dynamically.
|
||||
|
||||
Args:
|
||||
module (str): asr name. e.g., transformer, conformer
|
||||
|
||||
Returns:
|
||||
type: ASR class
|
||||
|
||||
"""
|
||||
model_class = dynamic_import(module, predefined_asr)
|
||||
assert issubclass(model_class,
|
||||
ASRInterface), f"{module} does not implement ASRInterface"
|
||||
return model_class
|
@ -0,0 +1,47 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# != 3 ];then
|
||||
echo "usage: ${0} config_path ckpt_path_prefix audio_file"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
config_path=$1
|
||||
ckpt_prefix=$2
|
||||
audio_file=$3
|
||||
|
||||
chunk_mode=false
|
||||
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
|
||||
chunk_mode=true
|
||||
fi
|
||||
|
||||
# download language model
|
||||
#bash local/download_lm_ch.sh
|
||||
#if [ $? -ne 0 ]; then
|
||||
# exit 1
|
||||
#fi
|
||||
|
||||
|
||||
|
||||
for type in attention_rescoring; do
|
||||
echo "decoding ${type}"
|
||||
batch_size=1
|
||||
output_dir=${ckpt_prefix}
|
||||
mkdir -p ${output_dir}
|
||||
python3 -u ${BIN_DIR}/test_hub.py \
|
||||
--nproc ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--result_file ${output_dir}/${type}.rsl \
|
||||
--checkpoint_path ${ckpt_prefix} \
|
||||
--opts decoding.decoding_method ${type} \
|
||||
--opts decoding.batch_size ${batch_size} \
|
||||
--audio_file ${audio_file}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
exit 0
|
@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# != 4 ];then
|
||||
echo "usage: ${0} config_path ckpt_path_prefix model_type audio_file"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
config_path=$1
|
||||
ckpt_prefix=$2
|
||||
model_type=$3
|
||||
audio_file=$4
|
||||
|
||||
# download language model
|
||||
bash local/download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
python3 -u ${BIN_DIR}/test_hub.py \
|
||||
--nproc ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--result_file ${ckpt_prefix}.rsl \
|
||||
--checkpoint_path ${ckpt_prefix} \
|
||||
--model_type ${model_type} \
|
||||
--audio_file ${audio_file}
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,54 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ $# != 3 ];then
|
||||
echo "usage: ${0} config_path ckpt_path_prefix audio_file"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
config_path=$1
|
||||
ckpt_prefix=$2
|
||||
audio_file=$3
|
||||
|
||||
# bpemode (unigram or bpe)
|
||||
nbpe=5000
|
||||
bpemode=unigram
|
||||
bpeprefix="data/bpe_${bpemode}_${nbpe}"
|
||||
bpemodel=${bpeprefix}.model
|
||||
|
||||
chunk_mode=false
|
||||
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
|
||||
chunk_mode=true
|
||||
fi
|
||||
|
||||
# download language model
|
||||
#bash local/download_lm_ch.sh
|
||||
#if [ $? -ne 0 ]; then
|
||||
# exit 1
|
||||
#fi
|
||||
|
||||
|
||||
for type in attention_rescoring; do
|
||||
echo "decoding ${type}"
|
||||
batch_size=1
|
||||
output_dir=${ckpt_prefix}
|
||||
mkdir -p ${output_dir}
|
||||
python3 -u ${BIN_DIR}/test_hub.py \
|
||||
--nproc ${ngpu} \
|
||||
--config ${config_path} \
|
||||
--result_file ${output_dir}/${type}.rsl \
|
||||
--checkpoint_path ${ckpt_prefix} \
|
||||
--opts decoding.decoding_method ${type} \
|
||||
--opts decoding.batch_size ${batch_size} \
|
||||
--audio_file ${audio_file}
|
||||
|
||||
#score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true ${expdir}/${decode_dir} ${dict}
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
exit 0
|
@ -1,9 +1,14 @@
|
||||
# LibriSpeech
|
||||
|
||||
## Transformer
|
||||
| Model | Params | Config | Augmentation| Test Set | Decode Method | Loss | WER % |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.395054340362549 | 4.2 |
|
||||
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.395054340362549 | 5.0 |
|
||||
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.395054340362549 | |
|
||||
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescore | 6.395054340362549 | |
|
||||
| Model | Params | Config | Augmentation| Loss |
|
||||
| --- | --- | --- | --- |
|
||||
| transformer | 32.52 M | conf/transformer.yaml | spec_aug | 6.3197922706604 |
|
||||
|
||||
|
||||
| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| test-clean | attention | 2620 | 52576 | 96.4 | 2.5 | 1.1 | 0.4 | 4.0 | 34.7 |
|
||||
| test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 |
|
||||
| test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 |
|
||||
| test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 |
|
||||
| test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 |
|
||||
|
@ -0,0 +1,7 @@
|
||||
batchsize: 0
|
||||
beam-size: 60
|
||||
ctc-weight: 0.0
|
||||
lm-weight: 0.0
|
||||
maxlenratio: 0.0
|
||||
minlenratio: 0.0
|
||||
penalty: 0.0
|
@ -0,0 +1,7 @@
|
||||
batchsize: 0
|
||||
beam-size: 60
|
||||
ctc-weight: 0.4
|
||||
lm-weight: 0.6
|
||||
maxlenratio: 0.0
|
||||
minlenratio: 0.0
|
||||
penalty: 0.0
|
@ -0,0 +1,7 @@
|
||||
batchsize: 0
|
||||
beam-size: 60
|
||||
ctc-weight: 0.4
|
||||
lm-weight: 0.0
|
||||
maxlenratio: 0.0
|
||||
minlenratio: 0.0
|
||||
penalty: 0.0
|
@ -0,0 +1,109 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
expdir=exp
|
||||
datadir=data
|
||||
nj=32
|
||||
tag=
|
||||
|
||||
# decode config
|
||||
decode_config=conf/decode/decode.yaml
|
||||
|
||||
# lm params
|
||||
lang_model=rnnlm.model.best
|
||||
lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/
|
||||
lmtag='nolm'
|
||||
|
||||
recog_set="test-clean test-other dev-clean dev-other"
|
||||
recog_set="test-clean"
|
||||
|
||||
# bpemode (unigram or bpe)
|
||||
nbpe=5000
|
||||
bpemode=unigram
|
||||
bpeprefix="data/bpe_${bpemode}_${nbpe}"
|
||||
bpemodel=${bpeprefix}.model
|
||||
|
||||
# bin params
|
||||
config_path=conf/transformer.yaml
|
||||
dict=data/bpe_unigram_5000_units.txt
|
||||
ckpt_prefix=
|
||||
|
||||
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
|
||||
|
||||
if [ -z ${ckpt_prefix} ]; then
|
||||
echo "usage: $0 --ckpt_prefix ckpt_prefix"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
|
||||
echo "using $ngpu gpus..."
|
||||
|
||||
ckpt_dir=$(dirname `dirname ${ckpt_prefix}`)
|
||||
echo "ckpt dir: ${ckpt_dir}"
|
||||
|
||||
ckpt_tag=$(basename ${ckpt_prefix})
|
||||
echo "ckpt tag: ${ckpt_tag}"
|
||||
|
||||
chunk_mode=false
|
||||
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
|
||||
chunk_mode=true
|
||||
fi
|
||||
echo "chunk mode: ${chunk_mode}"
|
||||
echo "decode conf: ${decode_config}"
|
||||
|
||||
# download language model
|
||||
#bash local/download_lm_en.sh
|
||||
#if [ $? -ne 0 ]; then
|
||||
# exit 1
|
||||
#fi
|
||||
|
||||
|
||||
pids=() # initialize pids
|
||||
|
||||
for dmethd in join_ctc; do
|
||||
(
|
||||
echo "${dmethd} decoding"
|
||||
for rtask in ${recog_set}; do
|
||||
(
|
||||
echo "${rtask} dataset"
|
||||
decode_dir=${ckpt_dir}/decode/decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag}_${ckpt_tag}_${tag}
|
||||
feat_recog_dir=${datadir}
|
||||
mkdir -p ${decode_dir}
|
||||
mkdir -p ${feat_recog_dir}
|
||||
|
||||
# split data
|
||||
split_json.sh manifest.${rtask} ${nj}
|
||||
|
||||
#### use CPU for decoding
|
||||
ngpu=0
|
||||
|
||||
# set batchsize 0 to disable batch decoding
|
||||
${decode_cmd} JOB=1:${nj} ${decode_dir}/log/decode.JOB.log \
|
||||
python3 -u ${BIN_DIR}/recog.py \
|
||||
--api v2 \
|
||||
--config ${decode_config} \
|
||||
--ngpu ${ngpu} \
|
||||
--batchsize 0 \
|
||||
--checkpoint_path ${ckpt_prefix} \
|
||||
--dict-path ${dict} \
|
||||
--recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \
|
||||
--result-label ${decode_dir}/data.JOB.json \
|
||||
--model-conf ${config_path} \
|
||||
--model ${ckpt_prefix}.pdparams
|
||||
|
||||
#--rnnlm ${lmexpdir}/${lang_model} \
|
||||
|
||||
score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict}
|
||||
|
||||
) &
|
||||
pids+=($!) # store background pids
|
||||
i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done
|
||||
[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." || true
|
||||
done
|
||||
)
|
||||
done
|
||||
|
||||
echo "Finished"
|
||||
|
||||
exit 0
|
@ -1,3 +1,3 @@
|
||||
# Punctation Restoration
|
||||
|
||||
Please using [PaddleSpeechTask](https://github.com/745165806/PaddleSpeechTask] to do this task.
|
||||
Please using [PaddleSpeechTask](https://github.com/745165806/PaddleSpeechTask) to do this task.
|
||||
|
Loading…
Reference in new issue