[ASR] Support CTC decoder online (#821)

* fix the destructer problem for prefixes

* unified offline and online in ctcdecoders, test=asr

* rename swig_decoders to paddlespeech_ctcdecoders, test=asr

* add reset_stage for ctcdecoder

* fix some problems

* fix ctconline

* fix a bug

* fix the format

* fix 1xt2x
pull/1380/head
Jackwaterveg 2 years ago committed by GitHub
parent 3dedea8582
commit d7222c0453
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -162,39 +162,17 @@ class DeepSpeech2Model(nn.Layer):
return loss
@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
def decode(self, audio, audio_len):
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
# Make sure the decoder has been initialized
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
print("probs.shape", probs.shape)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
def decode_probs_split(self, probs_split, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_prob, cutoff_top_n, num_processes):
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
return self.decoder.decode_probs_split(
probs_split, vocab_list, decoding_method, lang_model_path,
beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n,
num_processes)
batch_size = probs.shape[0]
self.decoder.reset_decoder(batch_size = batch_size)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
return trans_best
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):

@ -254,12 +254,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
vocab_list = self.test_loader.collate_fn.vocab_list
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.compute_result_transcripts(audio, audio_len,
vocab_list, cfg)
result_transcripts = self.compute_result_transcripts(audio, audio_len)
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
@ -280,19 +278,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
error_rate=errors_sum / len_refs,
error_rate_type=cfg.error_rate_type)
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
def compute_result_transcripts(self, audio, audio_len):
result_transcripts = self.model.decode(audio, audio_len)
result_transcripts = [
self._text_featurizer.detokenize(item)
for item in result_transcripts
@ -307,6 +295,17 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cfg = self.config
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
# Initialized the decoder in model
decode_cfg = self.config.decode
vocab_list = self.test_loader.collate_fn.vocab_list
decode_batch_size = self.test_loader.batch_size
self.model.decoder.init_decoder(
decode_batch_size, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
@ -326,6 +325,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
self.model.decoder.del_decoder()
def run_test(self):
self.resume_or_scratch()

@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .swig_wrapper import ctc_beam_search_decoding
from .swig_wrapper import ctc_beam_search_decoding_batch
from .swig_wrapper import ctc_greedy_decoding
from .swig_wrapper import CTCBeamSearchDecoder
from .swig_wrapper import Scorer

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper for various CTC decoders in SWIG."""
import swig_decoders
import paddlespeech_ctcdecoders
class Scorer(swig_decoders.Scorer):
class Scorer(paddlespeech_ctcdecoders.Scorer):
"""Wrapper for Scorer.
:param alpha: Parameter associated with language model. Don't use
@ -26,14 +26,17 @@ class Scorer(swig_decoders.Scorer):
:type beta: float
:model_path: Path to load language model.
:type model_path: str
:param vocabulary: Vocabulary list.
:type vocabulary: list
"""
def __init__(self, alpha, beta, model_path, vocabulary):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
paddlespeech_ctcdecoders.Scorer.__init__(self, alpha, beta, model_path,
vocabulary)
def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decoder in swig.
def ctc_greedy_decoding(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decodeing function in swig.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
@ -44,19 +47,19 @@ def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
:return: Decoding result string.
:rtype: str
"""
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
blank_id)
result = paddlespeech_ctcdecoders.ctc_greedy_decoding(probs_seq.tolist(),
vocabulary, blank_id)
return result
def ctc_beam_search_decoder(probs_seq,
vocabulary,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoder.
def ctc_beam_search_decoding(probs_seq,
vocabulary,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoding function.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
@ -81,22 +84,22 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability.
:rtype: list
"""
beam_results = swig_decoders.ctc_beam_search_decoder(
beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func, blank_id)
beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
return beam_results
def ctc_beam_search_decoder_batch(probs_split,
vocabulary,
beam_size,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decoder.
def ctc_beam_search_decoding_batch(probs_split,
vocabulary,
beam_size,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decodeing batch function.
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
@ -126,9 +129,31 @@ def ctc_beam_search_decoder_batch(probs_split,
"""
probs_split = [probs_seq.tolist() for probs_seq in probs_split]
batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
batch_beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results]
return batch_beam_results
class CTCBeamSearchDecoder(paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch):
"""Wrapper for CtcBeamSearchDecoderBatch.
Args:
vocab_list (list): Vocabulary list.
beam_size (int): Width for beam search.
num_processes (int): Number of parallel processes.
param cutoff_prob (float): Cutoff probability in vocabulary pruning,
default 1.0, no pruning.
cutoff_top_n (int): Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
param ext_scorer (Scorer): External scorer for partially decoded sentence, e.g. word count
or language model.
"""
def __init__(self, vocab_list, batch_size, beam_size, num_processes,
cutoff_prob, cutoff_top_n, _ext_scorer, blank_id):
paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch.__init__(
self, vocab_list, batch_size, beam_size, num_processes, cutoff_prob,
cutoff_top_n, _ext_scorer, blank_id)

@ -267,12 +267,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
vocab_list = self.test_loader.collate_fn.vocab_list
target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.compute_result_transcripts(
audio, audio_len, vocab_list, decode_cfg)
result_transcripts = self.compute_result_transcripts(audio, audio_len)
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
@ -296,21 +293,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
error_rate=errors_sum / len_refs,
error_rate_type=decode_cfg.error_rate_type)
def compute_result_transcripts(self, audio, audio_len, vocab_list,
decode_cfg):
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=decode_cfg.decoding_method,
lang_model_path=decode_cfg.lang_model_path,
beam_alpha=decode_cfg.alpha,
beam_beta=decode_cfg.beta,
beam_size=decode_cfg.beam_size,
cutoff_prob=decode_cfg.cutoff_prob,
cutoff_top_n=decode_cfg.cutoff_top_n,
num_processes=decode_cfg.num_proc_bsearch)
def compute_result_transcripts(self, audio, audio_len):
result_transcripts = self.model.decode(audio, audio_len)
return result_transcripts
@mp_tools.rank_zero_only
@ -320,6 +304,17 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
self.model.eval()
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
# Initialized the decoder in model
decode_cfg = self.config.decode
vocab_list = self.test_loader.collate_fn.vocab_list
decode_batch_size = self.test_loader.batch_size
self.model.decoder.init_decoder(
decode_batch_size, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
@ -339,6 +334,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
self.model.decoder.del_decoder()
@paddle.no_grad()
def export(self):
@ -377,6 +373,22 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
self.model.eval()
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
# Initialized the decoder in model
decode_cfg = self.config.decode
vocab_list = self.test_loader.collate_fn.vocab_list
if self.args.model_type == "online":
decode_batch_size = 1
elif self.args.model_type == "offline":
decode_batch_size = self.test_loader.batch_size
else:
raise Exception("wrong model type")
self.model.decoder.init_decoder(
decode_batch_size, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
@ -388,7 +400,6 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
error_rate_type = metrics['error_rate_type']
logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs))
# logging
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
@ -398,30 +409,31 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
logger.info(msg)
if self.args.enable_auto_log is True:
self.autolog.report()
self.model.decoder.del_decoder()
def compute_result_transcripts(self, audio, audio_len, vocab_list,
decode_cfg):
def compute_result_transcripts(self, audio, audio_len):
if self.args.model_type == "online":
output_probs, output_lens = self.static_forward_online(audio,
audio_len)
output_probs, output_lens, trans_batch = self.static_forward_online(
audio, audio_len, decoder_chunk_size=1)
result_transcripts = [trans[-1] for trans in trans_batch]
elif self.args.model_type == "offline":
output_probs, output_lens = self.static_forward_offline(audio,
audio_len)
batch_size = output_probs.shape[0]
self.model.decoder.reset_decoder(batch_size=batch_size)
self.model.decoder.next(output_probs, output_lens)
trans_best, trans_beam = self.model.decoder.decode()
result_transcripts = trans_best
else:
raise Exception("wrong model type")
self.predictor.clear_intermediate_tensor()
self.predictor.try_shrink_memory()
self.model.decoder.init_decode(decode_cfg.alpha, decode_cfg.beta,
decode_cfg.lang_model_path, vocab_list,
decode_cfg.decoding_method)
result_transcripts = self.model.decoder.decode_probs(
output_probs, output_lens, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
#replace the <space> with ' '
result_transcripts = [
self._text_featurizer.detokenize(sentence)
@ -451,6 +463,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
-------
output_probs(numpy.array): shape[B, T, vocab_size]
output_lens(numpy.array): shape[B]
trans(list(list(str))): shape[B, T]
"""
output_probs_list = []
output_lens_list = []
@ -464,14 +477,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
batch_size, Tmax, x_dim = x_batch.shape
x_len_batch = audio_len.numpy().astype(np.int64)
if (Tmax - chunk_size) % chunk_stride != 0:
padding_len_batch = chunk_stride - (
Tmax - chunk_size
) % chunk_stride # The length of padding for the batch
# The length of padding for the batch
padding_len_batch = chunk_stride - (Tmax - chunk_size
) % chunk_stride
else:
padding_len_batch = 0
x_list = np.split(x_batch, batch_size, axis=0)
x_len_list = np.split(x_len_batch, batch_size, axis=0)
trans_batch = []
for x, x_len in zip(x_list, x_len_list):
if self.args.enable_auto_log is True:
self.autolog.times.start()
@ -504,12 +518,14 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
h_box_handle = self.predictor.get_input_handle(input_names[2])
c_box_handle = self.predictor.get_input_handle(input_names[3])
trans = []
probs_chunk_list = []
probs_chunk_lens_list = []
if self.args.enable_auto_log is True:
# record the model preprocessing time
self.autolog.times.stamp()
self.model.decoder.reset_decoder(batch_size=1)
for i in range(0, num_chunk):
start = i * chunk_stride
end = start + chunk_size
@ -518,9 +534,8 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
x_chunk_lens = 0
else:
x_chunk_lens = min(x_len - i * chunk_stride, chunk_size)
if (x_chunk_lens <
receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob
#means the number of input frames in the chunk is not enough for predicting one prob
if (x_chunk_lens < receptive_field_length):
break
x_chunk_lens = np.array([x_chunk_lens])
audio_handle.reshape(x_chunk.shape)
@ -549,9 +564,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
output_chunk_lens = output_lens_handle.copy_to_cpu()
chunk_state_h_box = output_state_h_handle.copy_to_cpu()
chunk_state_c_box = output_state_c_handle.copy_to_cpu()
self.model.decoder.next(output_chunk_probs, output_chunk_lens)
probs_chunk_list.append(output_chunk_probs)
probs_chunk_lens_list.append(output_chunk_lens)
trans_best, trans_beam = self.model.decoder.decode()
trans.append(trans_best[0])
trans_batch.append(trans)
output_probs = np.concatenate(probs_chunk_list, axis=1)
output_lens = np.sum(probs_chunk_lens_list, axis=0)
vocab_size = output_probs.shape[2]
@ -573,7 +591,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
self.autolog.times.end()
output_probs = np.concatenate(output_probs_list, axis=0)
output_lens = np.concatenate(output_lens_list, axis=0)
return output_probs, output_lens
return output_probs, output_lens, trans_batch
def static_forward_offline(self, audio, audio_len):
"""

@ -16,7 +16,7 @@ from .deepspeech2 import DeepSpeech2Model
from paddlespeech.s2t.utils import dynamic_pip_install
try:
import swig_decoders
import paddlespeech_ctcdecoders
except ImportError:
try:
package_name = 'paddlespeech_ctcdecoders'

@ -164,24 +164,18 @@ class DeepSpeech2Model(nn.Layer):
return loss
@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
def decode(self, audio, audio_len):
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
# Make sure the decoder has been initialized
eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
batch_size = probs.shape[0]
self.decoder.reset_decoder(batch_size=batch_size)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
return trans_best
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):

@ -16,7 +16,7 @@ from .deepspeech2 import DeepSpeech2ModelOnline
from paddlespeech.s2t.utils import dynamic_pip_install
try:
import swig_decoders
import paddlespeech_ctcdecoders
except ImportError:
try:
package_name = 'paddlespeech_ctcdecoders'

@ -293,25 +293,17 @@ class DeepSpeech2ModelOnline(nn.Layer):
return loss
@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
def decode(self, audio, audio_len):
# decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
# Make sure the decoder has been initialized
eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len, None, None)
probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs(
probs.numpy(), eouts_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes)
batch_size = probs.shape[0]
self.decoder.reset_decoder(batch_size=batch_size)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
return trans_best
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):

@ -32,7 +32,7 @@ from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.models.asr_interface import ASRInterface
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.modules.ctc import CTCDecoderBase
from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
@ -63,7 +63,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
vocab_size: int,
encoder: TransformerEncoder,
decoder: TransformerDecoder,
ctc: CTCDecoder,
ctc: CTCDecoderBase,
ctc_weight: float=0.5,
ignore_id: int=IGNORE_ID,
lsm_weight: float=0.0,
@ -840,7 +840,7 @@ class U2Model(U2DecodeModel):
model_conf = configs.get('model_conf', dict())
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
ctc = CTCDecoder(
ctc = CTCDecoderBase(
odim=vocab_size,
enc_n_units=encoder.output_size(),
blank_id=0,

@ -28,7 +28,7 @@ from paddle import nn
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.modules.ctc import CTCDecoderBase
from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder
@ -56,7 +56,7 @@ class U2STBaseModel(nn.Layer):
encoder: TransformerEncoder,
st_decoder: TransformerDecoder,
decoder: TransformerDecoder=None,
ctc: CTCDecoder=None,
ctc: CTCDecoderBase=None,
ctc_weight: float=0.0,
asr_weight: float=0.0,
ignore_id: int=IGNORE_ID,
@ -313,8 +313,7 @@ class U2STBaseModel(nn.Layer):
cache = [
paddle.ones(
(len(hyps), i - 1, hyp_cache.shape[-1]),
dtype=paddle.float32)
for hyp_cache in hyps[0]["cache"]
dtype=paddle.float32) for hyp_cache in hyps[0]["cache"]
]
for j, hyp in enumerate(hyps):
ys[j, :] = paddle.to_tensor(hyp["yseq"])
@ -596,7 +595,7 @@ class U2STModel(U2STBaseModel):
model_conf = configs['model_conf']
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
ctc = CTCDecoder(
ctc = CTCDecoderBase(
odim=vocab_size,
enc_n_units=encoder.output_size(),
blank_id=0,

@ -25,17 +25,19 @@ from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
try:
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_greedy_decoder # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import Scorer # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import ctc_beam_search_decoding_batch # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import ctc_greedy_decoding # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import CTCBeamSearchDecoder # noqa: F401
except ImportError:
try:
from paddlespeech.s2t.utils import dynamic_pip_install
package_name = 'paddlespeech_ctcdecoders'
dynamic_pip_install.install(package_name)
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_greedy_decoder # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import Scorer # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import ctc_beam_search_decoding_batch # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import ctc_greedy_decoding # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import CTCBeamSearchDecoder # noqa: F401
except Exception as e:
logger.info("paddlespeech_ctcdecoders not installed!")
@ -139,9 +141,11 @@ class CTCDecoder(CTCDecoderBase):
super().__init__(*args, **kwargs)
# CTCDecoder LM Score handle
self._ext_scorer = None
self.beam_search_decoder = None
def _decode_batch_greedy(self, probs_split, vocab_list):
"""Decode by best path for a batch of probs matrix input.
def _decode_batch_greedy_offline(self, probs_split, vocab_list):
"""This function will be deprecated in future.
Decode by best path for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists
of prob vectors for one speech utterancce.
:param probs_split: List of matrix
@ -152,7 +156,7 @@ class CTCDecoder(CTCDecoderBase):
"""
results = []
for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder(
output_transcription = ctc_greedy_decoding(
probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id)
results.append(output_transcription)
return results
@ -194,10 +198,12 @@ class CTCDecoder(CTCDecoderBase):
logger.info("no language model provided, "
"decoding by pure beam search without scorer.")
def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
vocab_list, num_processes):
"""Decode by beam search for a batch of probs matrix input.
def _decode_batch_beam_search_offline(
self, probs_split, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, vocab_list, num_processes):
"""
This function will be deprecated in future.
Decode by beam search for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists
of prob vectors for one speech utterancce.
:param probs_split: List of matrix
@ -226,7 +232,7 @@ class CTCDecoder(CTCDecoderBase):
# beam search decode
num_processes = min(num_processes, len(probs_split))
beam_search_results = ctc_beam_search_decoder_batch(
beam_search_results = ctc_beam_search_decoding_batch(
probs_split=probs_split,
vocabulary=vocab_list,
beam_size=beam_size,
@ -239,30 +245,69 @@ class CTCDecoder(CTCDecoderBase):
results = [result[0][1] for result in beam_search_results]
return results
def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
decoding_method):
def init_decoder(self, batch_size, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_prob, cutoff_top_n, num_processes):
"""
init ctc decoders
Args:
batch_size(int): Batch size for input data
vocab_list (list): List of tokens in the vocabulary, for decoding
decoding_method (str): ctc_beam_search
lang_model_path (str): language model path
beam_alpha (float): beam_alpha
beam_beta (float): beam_beta
beam_size (int): beam_size
cutoff_prob (float): cutoff probability in beam search
cutoff_top_n (int): cutoff_top_n
num_processes (int): num_processes
Raises:
ValueError: when decoding_method not support.
Returns:
CTCBeamSearchDecoder
"""
self.batch_size = batch_size
self.vocab_list = vocab_list
self.decoding_method = decoding_method
self.beam_size = beam_size
self.cutoff_prob = cutoff_prob
self.cutoff_top_n = cutoff_top_n
self.num_processes = num_processes
if decoding_method == "ctc_beam_search":
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
vocab_list)
if self.beam_search_decoder is None:
self.beam_search_decoder = self.get_decoder(
vocab_list, batch_size, beam_alpha, beam_beta, beam_size,
num_processes, cutoff_prob, cutoff_top_n)
return self.beam_search_decoder
elif decoding_method == "ctc_greedy":
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
vocab_list)
else:
raise ValueError(f"Not support: {decoding_method}")
def decode_probs(self, probs, logits_lens, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_prob, cutoff_top_n, num_processes):
"""ctc decoding with probs.
def decode_probs_offline(self, probs, logits_lens, vocab_list,
decoding_method, lang_model_path, beam_alpha,
beam_beta, beam_size, cutoff_prob, cutoff_top_n,
num_processes):
"""
This function will be deprecated in future.
ctc decoding with probs.
Args:
probs (Tensor): activation after softmax
logits_lens (Tensor): audio output lens
vocab_list ([type]): [description]
decoding_method ([type]): [description]
lang_model_path ([type]): [description]
beam_alpha ([type]): [description]
beam_beta ([type]): [description]
beam_size ([type]): [description]
cutoff_prob ([type]): [description]
cutoff_top_n ([type]): [description]
num_processes ([type]): [description]
vocab_list (list): List of tokens in the vocabulary, for decoding
decoding_method (str): ctc_beam_search
lang_model_path (str): language model path
beam_alpha (float): beam_alpha
beam_beta (float): beam_beta
beam_size (int): beam_size
cutoff_prob (float): cutoff probability in beam search
cutoff_top_n (int): cutoff_top_n
num_processes (int): num_processes
Raises:
ValueError: when decoding_method not support.
@ -270,13 +315,14 @@ class CTCDecoder(CTCDecoderBase):
Returns:
List[str]: transcripts.
"""
logger.warn(
"This function will be deprecated in future: decode_probs_offline")
probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)]
if decoding_method == "ctc_greedy":
result_transcripts = self._decode_batch_greedy(
result_transcripts = self._decode_batch_greedy_offline(
probs_split=probs_split, vocab_list=vocab_list)
elif decoding_method == "ctc_beam_search":
result_transcripts = self._decode_batch_beam_search(
result_transcripts = self._decode_batch_beam_search_offline(
probs_split=probs_split,
beam_alpha=beam_alpha,
beam_beta=beam_beta,
@ -288,3 +334,136 @@ class CTCDecoder(CTCDecoderBase):
else:
raise ValueError(f"Not support: {decoding_method}")
return result_transcripts
def get_decoder(self, vocab_list, batch_size, beam_alpha, beam_beta,
beam_size, num_processes, cutoff_prob, cutoff_top_n):
"""
init get ctc decoder
Args:
vocab_list (list): List of tokens in the vocabulary, for decoding.
batch_size(int): Batch size for input data
beam_alpha (float): beam_alpha
beam_beta (float): beam_beta
beam_size (int): beam_size
num_processes (int): num_processes
cutoff_prob (float): cutoff probability in beam search
cutoff_top_n (int): cutoff_top_n
Raises:
ValueError: when decoding_method not support.
Returns:
CTCBeamSearchDecoder
"""
num_processes = min(num_processes, batch_size)
if self._ext_scorer is not None:
self._ext_scorer.reset_params(beam_alpha, beam_beta)
if self.decoding_method == "ctc_beam_search":
beam_search_decoder = CTCBeamSearchDecoder(
vocab_list, batch_size, beam_size, num_processes, cutoff_prob,
cutoff_top_n, self._ext_scorer, self.blank_id)
else:
raise ValueError(f"Not support: {decoding_method}")
return beam_search_decoder
def next(self, probs, logits_lens):
"""
Input probs into ctc decoder
Args:
probs (list(list(float))): probs for a batch of data
logits_lens (list(int)): logits lens for a batch of data
Raises:
Exception: when the ctc decoder is not initialized
ValueError: when decoding_method not support.
"""
if self.beam_search_decoder is None:
raise Exception(
"You need to initialize the beam_search_decoder firstly")
beam_search_decoder = self.beam_search_decoder
has_value = (logits_lens > 0).tolist()
has_value = [
"true" if has_value[i] is True else "false"
for i in range(len(has_value))
]
probs_split = [
probs[i, :l, :].tolist() if has_value[i] else probs[i].tolist()
for i, l in enumerate(logits_lens)
]
if self.decoding_method == "ctc_beam_search":
beam_search_decoder.next(probs_split, has_value)
else:
raise ValueError(f"Not support: {decoding_method}")
return
def decode(self):
"""
Get the decoding result
Raises:
Exception: when the ctc decoder is not initialized
ValueError: when decoding_method not support.
Returns:
results_best (list(str)): The best result for a batch of data
results_beam (list(list(str))): The beam search result for a batch of data
"""
if self.beam_search_decoder is None:
raise Exception(
"You need to initialize the beam_search_decoder firstly")
beam_search_decoder = self.beam_search_decoder
if self.decoding_method == "ctc_beam_search":
batch_beam_results = beam_search_decoder.decode()
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results]
results_best = [result[0][1] for result in batch_beam_results]
results_beam = [[trans[1] for trans in result]
for result in batch_beam_results]
else:
raise ValueError(f"Not support: {decoding_method}")
return results_best, results_beam
def reset_decoder(self,
batch_size=-1,
beam_size=-1,
num_processes=-1,
cutoff_prob=-1.0,
cutoff_top_n=-1):
if batch_size > 0:
self.batch_size = batch_size
if beam_size > 0:
self.beam_size = beam_size
if num_processes > 0:
self.num_processes = num_processes
if cutoff_prob > 0:
self.cutoff_prob = cutoff_prob
if cutoff_top_n > 0:
self.cutoff_top_n = cutoff_top_n
"""
Reset the decoder state
Args:
batch_size(int): Batch size for input data
beam_size (int): beam_size
num_processes (int): num_processes
cutoff_prob (float): cutoff probability in beam search
cutoff_top_n (int): cutoff_top_n
Raises:
Exception: when the ctc decoder is not initialized
"""
if self.beam_search_decoder is None:
raise Exception(
"You need to initialize the beam_search_decoder firstly")
self.beam_search_decoder.reset_state(
self.batch_size, self.beam_size, self.num_processes,
self.cutoff_prob, self.cutoff_top_n)
def del_decoder(self):
"""
Delete the decoder
"""
if self.beam_search_decoder is not None:
del self.beam_search_decoder
self.beam_search_decoder = None

@ -29,7 +29,8 @@
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::pair<double, std::string>> ctc_beam_search_decoding(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
@ -46,6 +47,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
// assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin();
@ -206,7 +209,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
ctc_beam_search_decoding_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &vocabulary,
size_t beam_size,
@ -224,7 +227,7 @@ ctc_beam_search_decoder_batch(
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
res.emplace_back(pool.enqueue(ctc_beam_search_decoding,
probs_split[i],
vocabulary,
beam_size,
@ -241,3 +244,364 @@ ctc_beam_search_decoder_batch(
}
return batch_results;
}
void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer) {
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict =
static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root->set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root->set_matcher(matcher);
}
}
void ctc_beam_search_decode_chunk(
PathTrie *root,
std::vector<PathTrie *> &prefixes,
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
// vocabulary.size() + 1,
vocabulary.size(),
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
// assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// init prefixes' root
//
// prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
auto &prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}
std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur =
log_sum_exp(prefix->log_prob_nb_cur,
log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
prefixes.clear();
// update log probs
root->iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
}
} // end of loop over time
return;
}
std::vector<std::pair<double, std::string>> get_decode_result(
std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size,
Scorer *ext_scorer) {
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score += score;
}
}
}
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute aproximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (ext_scorer != nullptr) {
std::vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -=
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
std::vector<std::pair<double, std::string>> res =
get_beam_search_result(prefixes, vocabulary, beam_size);
// pay back the last word of each prefix that doesn't end with space (for
// decoding by chunk)
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score -= score;
}
}
}
return res;
}
void free_storage(std::unique_ptr<CtcBeamSearchDecoderStorage> &storage) {
storage = nullptr;
}
CtcBeamSearchDecoderBatch::~CtcBeamSearchDecoderBatch() {}
CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch(
const std::vector<std::string> &vocabulary,
size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id)
: batch_size(batch_size),
beam_size(beam_size),
num_processes(num_processes),
cutoff_prob(cutoff_prob),
cutoff_top_n(cutoff_top_n),
ext_scorer(ext_scorer),
blank_id(blank_id) {
VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!");
VALID_CHECK_GT(
this->num_processes, 0, "num_processes must be nonnegative!");
this->vocabulary = vocabulary;
for (size_t i = 0; i < batch_size; i++) {
this->decoder_storage_vector.push_back(
std::unique_ptr<CtcBeamSearchDecoderStorage>(
new CtcBeamSearchDecoderStorage()));
ctc_beam_search_decode_chunk_begin(
this->decoder_storage_vector[i]->root, ext_scorer);
}
};
/**
* Input
* probs_split: shape [B, T, D]
*/
void CtcBeamSearchDecoderBatch::next(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &has_value) {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
// thread pool
size_t num_has_value = 0;
for (int i = 0; i < has_value.size(); i++)
if (has_value[i] == "true") num_has_value += 1;
ThreadPool pool(std::min(num_processes, num_has_value));
// number of samples
size_t probs_num = probs_split.size();
VALID_CHECK_EQ(this->batch_size,
probs_num,
"The batch size of the current input data should be same "
"with the input data before");
// enqueue the tasks of decoding
std::vector<std::future<void>> res;
for (size_t i = 0; i < batch_size; ++i) {
if (has_value[i] == "true") {
res.emplace_back(pool.enqueue(
ctc_beam_search_decode_chunk,
std::ref(this->decoder_storage_vector[i]->root),
std::ref(this->decoder_storage_vector[i]->prefixes),
probs_split[i],
this->vocabulary,
this->beam_size,
this->cutoff_prob,
this->cutoff_top_n,
this->ext_scorer,
this->blank_id));
}
}
for (size_t i = 0; i < batch_size; ++i) {
res[i].get();
}
return;
};
/**
* Return
* batch_result: shape[B, beam_size,(-approx_ctc score, string)]
*/
std::vector<std::vector<std::pair<double, std::string>>>
CtcBeamSearchDecoderBatch::decode() {
VALID_CHECK_GT(
this->num_processes, 0, "num_processes must be nonnegative!");
// thread pool
ThreadPool pool(this->num_processes);
// number of samples
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < this->batch_size; ++i) {
res.emplace_back(
pool.enqueue(get_decode_result,
std::ref(this->decoder_storage_vector[i]->prefixes),
this->vocabulary,
this->beam_size,
this->ext_scorer));
}
// get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (size_t i = 0; i < this->batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}
/**
* reset the state of ctcBeamSearchDecoderBatch
*/
void CtcBeamSearchDecoderBatch::reset_state(size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n) {
this->batch_size = batch_size;
this->beam_size = beam_size;
this->num_processes = num_processes;
this->cutoff_prob = cutoff_prob;
this->cutoff_top_n = cutoff_top_n;
VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!");
VALID_CHECK_GT(
this->num_processes, 0, "num_processes must be nonnegative!");
// thread pool
ThreadPool pool(this->num_processes);
// number of samples
// enqueue the tasks of decoding
std::vector<std::future<void>> res;
size_t storage_size = decoder_storage_vector.size();
for (size_t i = 0; i < storage_size; i++) {
res.emplace_back(pool.enqueue(
free_storage, std::ref(this->decoder_storage_vector[i])));
}
for (size_t i = 0; i < storage_size; ++i) {
res[i].get();
}
std::vector<std::unique_ptr<CtcBeamSearchDecoderStorage>>().swap(
decoder_storage_vector);
for (size_t i = 0; i < this->batch_size; i++) {
this->decoder_storage_vector.push_back(
std::unique_ptr<CtcBeamSearchDecoderStorage>(
new CtcBeamSearchDecoderStorage()));
ctc_beam_search_decode_chunk_begin(
this->decoder_storage_vector[i]->root, this->ext_scorer);
}
}

@ -37,7 +37,7 @@
* A vector that each element is a pair of score and decoding result,
* in desending order.
*/
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::pair<double, std::string>> ctc_beam_search_decoding(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
@ -46,6 +46,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
/* CTC Beam Search Decoder for batch data
* Parameters:
@ -64,7 +65,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
* result for one audio sample.
*/
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
ctc_beam_search_decoding_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &vocabulary,
size_t beam_size,
@ -74,4 +75,101 @@ ctc_beam_search_decoder_batch(
Scorer *ext_scorer = nullptr,
size_t blank_id = 0);
/**
* Store the root and prefixes for decoder
*/
class CtcBeamSearchDecoderStorage {
public:
PathTrie *root = nullptr;
std::vector<PathTrie *> prefixes;
CtcBeamSearchDecoderStorage() {
// init prefixes' root
this->root = new PathTrie();
this->root->log_prob_b_prev = 0.0;
// The score of root is in log scale.Since the prob=1.0, the prob score
// in log scale is 0.0
this->root->score = root->log_prob_b_prev;
// std::vector<PathTrie *> prefixes;
this->prefixes.push_back(root);
};
~CtcBeamSearchDecoderStorage() {
if (root != nullptr) {
delete root;
root = nullptr;
}
};
};
/**
* The ctc beam search decoder, support batchsize >= 1
*/
class CtcBeamSearchDecoderBatch {
public:
CtcBeamSearchDecoderBatch(const std::vector<std::string> &vocabulary,
size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id);
~CtcBeamSearchDecoderBatch();
void next(const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &has_value);
std::vector<std::vector<std::pair<double, std::string>>> decode();
void reset_state(size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n);
private:
std::vector<std::string> vocabulary;
size_t batch_size;
size_t beam_size;
size_t num_processes;
double cutoff_prob;
size_t cutoff_top_n;
Scorer *ext_scorer;
size_t blank_id;
std::vector<std::unique_ptr<CtcBeamSearchDecoderStorage>>
decoder_storage_vector;
};
/**
* function for chunk decoding
*/
void ctc_beam_search_decode_chunk(
PathTrie *root,
std::vector<PathTrie *> &prefixes,
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id);
std::vector<std::pair<double, std::string>> get_decode_result(
std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size,
Scorer *ext_scorer);
/**
* free the CtcBeamSearchDecoderStorage
*/
void free_storage(std::unique_ptr<CtcBeamSearchDecoderStorage> &storage);
/**
* initialize the root
*/
void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer);
#endif // CTC_BEAM_SEARCH_DECODER_H_

@ -15,7 +15,7 @@
#include "ctc_greedy_decoder.h"
#include "decoder_utils.h"
std::string ctc_greedy_decoder(
std::string ctc_greedy_decoding(
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t blank_id) {

@ -27,7 +27,7 @@
* Return:
* The decoding result in string
*/
std::string ctc_greedy_decoder(
std::string ctc_greedy_decoding(
const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary,
size_t blank_id);

@ -1,4 +1,4 @@
%module swig_decoders
%module paddlespeech_ctcdecoders
%{
#include "scorer.h"
#include "ctc_greedy_decoder.h"

@ -44,6 +44,7 @@ PathTrie::PathTrie() {
PathTrie::~PathTrie() {
for (auto child : children_) {
delete child.second;
child.second = nullptr;
}
}
@ -131,26 +132,26 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
void PathTrie::remove() {
exists_ = false;
if (children_.size() == 0) {
auto child = parent->children_.begin();
for (child = parent->children_.begin();
child != parent->children_.end();
++child) {
if (child->first == character) {
parent->children_.erase(child);
break;
if (parent != nullptr) {
auto child = parent->children_.begin();
for (child = parent->children_.begin();
child != parent->children_.end();
++child) {
if (child->first == character) {
parent->children_.erase(child);
break;
}
}
if (parent->children_.size() == 0 && !parent->exists_) {
parent->remove();
}
}
if (parent->children_.size() == 0 && !parent->exists_) {
parent->remove();
}
delete this;
}
}
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
dictionary_ = dictionary;
dictionary_state_ = dictionary->Start();

@ -1,4 +1,5 @@
// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the "COPYING.LESSER.3");
// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the
// "COPYING.LESSER.3");
#include "scorer.h"

@ -1,4 +1,5 @@
// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the "COPYING.LESSER.3");
// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the
// "COPYING.LESSER.3");
#ifndef SCORER_H_
#define SCORER_H_

@ -112,7 +112,7 @@ os.system('swig -python -c++ ./decoders.i')
decoders_module = [
Extension(
name='_swig_decoders',
name='_paddlespeech_ctcdecoders',
sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
language='c++',
include_dirs=[
@ -134,4 +134,4 @@ setup(
url="https://github.com/PaddlePaddle/PaddleSpeech",
license='Apache 2.0, GNU Lesser General Public License v3 (LGPLv3) (LGPL-3)',
ext_modules=decoders_module,
py_modules=['swig_decoders'])
py_modules=['paddlespeech_ctcdecoders'])

Loading…
Cancel
Save