[ASR] Support CTC decoder online (#821)

* fix the destructor problem for prefixes

* unified offline and online in ctcdecoders, test=asr

* rename swig_decoders to paddlespeech_ctcdecoders, test=asr

* add reset_stage for ctcdecoder

* fix some problems

* fix ctconline

* fix a bug

* fix the format

* fix 1xt2x
pull/1380/head
Jackwaterveg 2 years ago committed by GitHub
parent 3dedea8582
commit d7222c0453
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -162,39 +162,17 @@ class DeepSpeech2Model(nn.Layer):
return loss return loss
@paddle.no_grad() @paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method, def decode(self, audio, audio_len):
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8 # decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
# Make sure the decoder has been initialized
eouts, eouts_len = self.encoder(audio, audio_len) eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts) probs = self.decoder.softmax(eouts)
print("probs.shape", probs.shape) batch_size = probs.shape[0]
return self.decoder.decode_probs( self.decoder.reset_decoder(batch_size = batch_size)
probs.numpy(), eouts_len, vocab_list, decoding_method, self.decoder.next(probs, eouts_len)
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, trans_best, trans_beam = self.decoder.decode()
cutoff_top_n, num_processes) return trans_best
def decode_probs_split(self, probs_split, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_prob, cutoff_top_n, num_processes):
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
return self.decoder.decode_probs_split(
probs_split, vocab_list, decoding_method, lang_model_path,
beam_alpha, beam_beta, beam_size, cutoff_prob, cutoff_top_n,
num_processes)
@classmethod @classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path): def from_pretrained(cls, dataloader, config, checkpoint_path):

@ -254,12 +254,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
vocab_list = self.test_loader.collate_fn.vocab_list
target_transcripts = self.ordid2token(texts, texts_len) target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.compute_result_transcripts(audio, audio_len, result_transcripts = self.compute_result_transcripts(audio, audio_len)
vocab_list, cfg)
for utt, target, result in zip(utts, target_transcripts, for utt, target, result in zip(utts, target_transcripts,
result_transcripts): result_transcripts):
errors, len_ref = errors_func(target, result) errors, len_ref = errors_func(target, result)
@ -280,19 +278,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
error_rate=errors_sum / len_refs, error_rate=errors_sum / len_refs,
error_rate_type=cfg.error_rate_type) error_rate_type=cfg.error_rate_type)
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): def compute_result_transcripts(self, audio, audio_len):
result_transcripts = self.model.decode( result_transcripts = self.model.decode(audio, audio_len)
audio,
audio_len,
vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
result_transcripts = [ result_transcripts = [
self._text_featurizer.detokenize(item) self._text_featurizer.detokenize(item)
for item in result_transcripts for item in result_transcripts
@ -307,6 +295,17 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cfg = self.config cfg = self.config
error_rate_type = None error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
# Initialize the decoder in the model
decode_cfg = self.config.decode
vocab_list = self.test_loader.collate_fn.vocab_list
decode_batch_size = self.test_loader.batch_size
self.model.decoder.init_decoder(
decode_batch_size, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
with open(self.args.result_file, 'w') as fout: with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch utts, audio, audio_len, texts, texts_len = batch
@ -326,6 +325,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
msg += "Final error rate [%s] (%d/%d) = %f" % ( msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs) error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg) logger.info(msg)
self.model.decoder.del_decoder()
def run_test(self): def run_test(self):
self.resume_or_scratch() self.resume_or_scratch()

@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .swig_wrapper import ctc_beam_search_decoding
from .swig_wrapper import ctc_beam_search_decoding_batch
from .swig_wrapper import ctc_greedy_decoding
from .swig_wrapper import CTCBeamSearchDecoder
from .swig_wrapper import Scorer

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Wrapper for various CTC decoders in SWIG.""" """Wrapper for various CTC decoders in SWIG."""
import swig_decoders import paddlespeech_ctcdecoders
class Scorer(swig_decoders.Scorer): class Scorer(paddlespeech_ctcdecoders.Scorer):
"""Wrapper for Scorer. """Wrapper for Scorer.
:param alpha: Parameter associated with language model. Don't use :param alpha: Parameter associated with language model. Don't use
@ -26,14 +26,17 @@ class Scorer(swig_decoders.Scorer):
:type beta: float :type beta: float
:model_path: Path to load language model. :model_path: Path to load language model.
:type model_path: str :type model_path: str
:param vocabulary: Vocabulary list.
:type vocabulary: list
""" """
def __init__(self, alpha, beta, model_path, vocabulary): def __init__(self, alpha, beta, model_path, vocabulary):
swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary) paddlespeech_ctcdecoders.Scorer.__init__(self, alpha, beta, model_path,
vocabulary)
def ctc_greedy_decoder(probs_seq, vocabulary, blank_id): def ctc_greedy_decoding(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decoder in swig. """Wrapper for ctc best path decoding function in swig.
:param probs_seq: 2-D list of probability distributions over each time :param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized step, with each element being a list of normalized
@ -44,19 +47,19 @@ def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
:return: Decoding result string. :return: Decoding result string.
:rtype: str :rtype: str
""" """
result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary, result = paddlespeech_ctcdecoders.ctc_greedy_decoding(probs_seq.tolist(),
blank_id) vocabulary, blank_id)
return result return result
def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoding(probs_seq,
vocabulary, vocabulary,
beam_size, beam_size,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
ext_scoring_func=None, ext_scoring_func=None,
blank_id=0): blank_id=0):
"""Wrapper for the CTC Beam Search Decoder. """Wrapper for the CTC Beam Search Decoding function.
:param probs_seq: 2-D list of probability distributions over each time :param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized step, with each element being a list of normalized
@ -81,22 +84,22 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the probability. results, in descending order of the probability.
:rtype: list :rtype: list
""" """
beam_results = swig_decoders.ctc_beam_search_decoder( beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n, probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func, blank_id) ext_scoring_func, blank_id)
beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results] beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
return beam_results return beam_results
def ctc_beam_search_decoder_batch(probs_split, def ctc_beam_search_decoding_batch(probs_split,
vocabulary, vocabulary,
beam_size, beam_size,
num_processes, num_processes,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
ext_scoring_func=None, ext_scoring_func=None,
blank_id=0): blank_id=0):
"""Wrapper for the batched CTC beam search decoder. """Wrapper for the batched CTC beam search decoding batch function.
:param probs_seq: 3-D list with each element as an instance of 2-D list :param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder(). of probabilities used by ctc_beam_search_decoder().
@ -126,9 +129,31 @@ def ctc_beam_search_decoder_batch(probs_split,
""" """
probs_split = [probs_seq.tolist() for probs_seq in probs_split] probs_split = [probs_seq.tolist() for probs_seq in probs_split]
batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch( batch_beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob, probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func, blank_id) cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results] batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results] for beam_results in batch_beam_results]
return batch_beam_results return batch_beam_results
class CTCBeamSearchDecoder(paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch):
"""Wrapper for CtcBeamSearchDecoderBatch.
Args:
vocab_list (list): Vocabulary list.
beam_size (int): Width for beam search.
num_processes (int): Number of parallel processes.
cutoff_prob (float): Cutoff probability in vocabulary pruning,
default 1.0, no pruning.
cutoff_top_n (int): Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
ext_scorer (Scorer): External scorer for partially decoded sentence, e.g. word count
or language model.
"""
def __init__(self, vocab_list, batch_size, beam_size, num_processes,
cutoff_prob, cutoff_top_n, _ext_scorer, blank_id):
paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch.__init__(
self, vocab_list, batch_size, beam_size, num_processes, cutoff_prob,
cutoff_top_n, _ext_scorer, blank_id)

@ -267,12 +267,9 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
vocab_list = self.test_loader.collate_fn.vocab_list
target_transcripts = self.ordid2token(texts, texts_len) target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.compute_result_transcripts( result_transcripts = self.compute_result_transcripts(audio, audio_len)
audio, audio_len, vocab_list, decode_cfg)
for utt, target, result in zip(utts, target_transcripts, for utt, target, result in zip(utts, target_transcripts,
result_transcripts): result_transcripts):
@ -296,21 +293,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
error_rate=errors_sum / len_refs, error_rate=errors_sum / len_refs,
error_rate_type=decode_cfg.error_rate_type) error_rate_type=decode_cfg.error_rate_type)
def compute_result_transcripts(self, audio, audio_len, vocab_list, def compute_result_transcripts(self, audio, audio_len):
decode_cfg): result_transcripts = self.model.decode(audio, audio_len)
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=decode_cfg.decoding_method,
lang_model_path=decode_cfg.lang_model_path,
beam_alpha=decode_cfg.alpha,
beam_beta=decode_cfg.beta,
beam_size=decode_cfg.beam_size,
cutoff_prob=decode_cfg.cutoff_prob,
cutoff_top_n=decode_cfg.cutoff_top_n,
num_processes=decode_cfg.num_proc_bsearch)
return result_transcripts return result_transcripts
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
@ -320,6 +304,17 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
self.model.eval() self.model.eval()
error_rate_type = None error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
# Initialize the decoder in the model
decode_cfg = self.config.decode
vocab_list = self.test_loader.collate_fn.vocab_list
decode_batch_size = self.test_loader.batch_size
self.model.decoder.init_decoder(
decode_batch_size, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
with jsonlines.open(self.args.result_file, 'w') as fout: with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch utts, audio, audio_len, texts, texts_len = batch
@ -339,6 +334,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
msg += "Final error rate [%s] (%d/%d) = %f" % ( msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs) error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg) logger.info(msg)
self.model.decoder.del_decoder()
@paddle.no_grad() @paddle.no_grad()
def export(self): def export(self):
@ -377,6 +373,22 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
self.model.eval() self.model.eval()
error_rate_type = None error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
# Initialize the decoder in the model
decode_cfg = self.config.decode
vocab_list = self.test_loader.collate_fn.vocab_list
if self.args.model_type == "online":
decode_batch_size = 1
elif self.args.model_type == "offline":
decode_batch_size = self.test_loader.batch_size
else:
raise Exception("wrong model type")
self.model.decoder.init_decoder(
decode_batch_size, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
with jsonlines.open(self.args.result_file, 'w') as fout: with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch utts, audio, audio_len, texts, texts_len = batch
@ -388,7 +400,6 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
error_rate_type = metrics['error_rate_type'] error_rate_type = metrics['error_rate_type']
logger.info("Error rate [%s] (%d/?) = %f" % logger.info("Error rate [%s] (%d/?) = %f" %
(error_rate_type, num_ins, errors_sum / len_refs)) (error_rate_type, num_ins, errors_sum / len_refs))
# logging # logging
msg = "Test: " msg = "Test: "
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
@ -398,30 +409,31 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
logger.info(msg) logger.info(msg)
if self.args.enable_auto_log is True: if self.args.enable_auto_log is True:
self.autolog.report() self.autolog.report()
self.model.decoder.del_decoder()
def compute_result_transcripts(self, audio, audio_len, vocab_list, def compute_result_transcripts(self, audio, audio_len):
decode_cfg):
if self.args.model_type == "online": if self.args.model_type == "online":
output_probs, output_lens = self.static_forward_online(audio, output_probs, output_lens, trans_batch = self.static_forward_online(
audio_len) audio, audio_len, decoder_chunk_size=1)
result_transcripts = [trans[-1] for trans in trans_batch]
elif self.args.model_type == "offline": elif self.args.model_type == "offline":
output_probs, output_lens = self.static_forward_offline(audio, output_probs, output_lens = self.static_forward_offline(audio,
audio_len) audio_len)
batch_size = output_probs.shape[0]
self.model.decoder.reset_decoder(batch_size=batch_size)
self.model.decoder.next(output_probs, output_lens)
trans_best, trans_beam = self.model.decoder.decode()
result_transcripts = trans_best
else: else:
raise Exception("wrong model type") raise Exception("wrong model type")
self.predictor.clear_intermediate_tensor() self.predictor.clear_intermediate_tensor()
self.predictor.try_shrink_memory() self.predictor.try_shrink_memory()
self.model.decoder.init_decode(decode_cfg.alpha, decode_cfg.beta,
decode_cfg.lang_model_path, vocab_list,
decode_cfg.decoding_method)
result_transcripts = self.model.decoder.decode_probs(
output_probs, output_lens, vocab_list, decode_cfg.decoding_method,
decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
decode_cfg.beam_size, decode_cfg.cutoff_prob,
decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
#replace the <space> with ' ' #replace the <space> with ' '
result_transcripts = [ result_transcripts = [
self._text_featurizer.detokenize(sentence) self._text_featurizer.detokenize(sentence)
@ -451,6 +463,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
------- -------
output_probs(numpy.array): shape[B, T, vocab_size] output_probs(numpy.array): shape[B, T, vocab_size]
output_lens(numpy.array): shape[B] output_lens(numpy.array): shape[B]
trans(list(list(str))): shape[B, T]
""" """
output_probs_list = [] output_probs_list = []
output_lens_list = [] output_lens_list = []
@ -464,14 +477,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
batch_size, Tmax, x_dim = x_batch.shape batch_size, Tmax, x_dim = x_batch.shape
x_len_batch = audio_len.numpy().astype(np.int64) x_len_batch = audio_len.numpy().astype(np.int64)
if (Tmax - chunk_size) % chunk_stride != 0: if (Tmax - chunk_size) % chunk_stride != 0:
padding_len_batch = chunk_stride - ( # The length of padding for the batch
Tmax - chunk_size padding_len_batch = chunk_stride - (Tmax - chunk_size
) % chunk_stride # The length of padding for the batch ) % chunk_stride
else: else:
padding_len_batch = 0 padding_len_batch = 0
x_list = np.split(x_batch, batch_size, axis=0) x_list = np.split(x_batch, batch_size, axis=0)
x_len_list = np.split(x_len_batch, batch_size, axis=0) x_len_list = np.split(x_len_batch, batch_size, axis=0)
trans_batch = []
for x, x_len in zip(x_list, x_len_list): for x, x_len in zip(x_list, x_len_list):
if self.args.enable_auto_log is True: if self.args.enable_auto_log is True:
self.autolog.times.start() self.autolog.times.start()
@ -504,12 +518,14 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
h_box_handle = self.predictor.get_input_handle(input_names[2]) h_box_handle = self.predictor.get_input_handle(input_names[2])
c_box_handle = self.predictor.get_input_handle(input_names[3]) c_box_handle = self.predictor.get_input_handle(input_names[3])
trans = []
probs_chunk_list = [] probs_chunk_list = []
probs_chunk_lens_list = [] probs_chunk_lens_list = []
if self.args.enable_auto_log is True: if self.args.enable_auto_log is True:
# record the model preprocessing time # record the model preprocessing time
self.autolog.times.stamp() self.autolog.times.stamp()
self.model.decoder.reset_decoder(batch_size=1)
for i in range(0, num_chunk): for i in range(0, num_chunk):
start = i * chunk_stride start = i * chunk_stride
end = start + chunk_size end = start + chunk_size
@ -518,9 +534,8 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
x_chunk_lens = 0 x_chunk_lens = 0
else: else:
x_chunk_lens = min(x_len - i * chunk_stride, chunk_size) x_chunk_lens = min(x_len - i * chunk_stride, chunk_size)
#means the number of input frames in the chunk is not enough for predicting one prob
if (x_chunk_lens < if (x_chunk_lens < receptive_field_length):
receptive_field_length): #means the number of input frames in the chunk is not enough for predicting one prob
break break
x_chunk_lens = np.array([x_chunk_lens]) x_chunk_lens = np.array([x_chunk_lens])
audio_handle.reshape(x_chunk.shape) audio_handle.reshape(x_chunk.shape)
@ -549,9 +564,12 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
output_chunk_lens = output_lens_handle.copy_to_cpu() output_chunk_lens = output_lens_handle.copy_to_cpu()
chunk_state_h_box = output_state_h_handle.copy_to_cpu() chunk_state_h_box = output_state_h_handle.copy_to_cpu()
chunk_state_c_box = output_state_c_handle.copy_to_cpu() chunk_state_c_box = output_state_c_handle.copy_to_cpu()
self.model.decoder.next(output_chunk_probs, output_chunk_lens)
probs_chunk_list.append(output_chunk_probs) probs_chunk_list.append(output_chunk_probs)
probs_chunk_lens_list.append(output_chunk_lens) probs_chunk_lens_list.append(output_chunk_lens)
trans_best, trans_beam = self.model.decoder.decode()
trans.append(trans_best[0])
trans_batch.append(trans)
output_probs = np.concatenate(probs_chunk_list, axis=1) output_probs = np.concatenate(probs_chunk_list, axis=1)
output_lens = np.sum(probs_chunk_lens_list, axis=0) output_lens = np.sum(probs_chunk_lens_list, axis=0)
vocab_size = output_probs.shape[2] vocab_size = output_probs.shape[2]
@ -573,7 +591,7 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
self.autolog.times.end() self.autolog.times.end()
output_probs = np.concatenate(output_probs_list, axis=0) output_probs = np.concatenate(output_probs_list, axis=0)
output_lens = np.concatenate(output_lens_list, axis=0) output_lens = np.concatenate(output_lens_list, axis=0)
return output_probs, output_lens return output_probs, output_lens, trans_batch
def static_forward_offline(self, audio, audio_len): def static_forward_offline(self, audio, audio_len):
""" """

@ -16,7 +16,7 @@ from .deepspeech2 import DeepSpeech2Model
from paddlespeech.s2t.utils import dynamic_pip_install from paddlespeech.s2t.utils import dynamic_pip_install
try: try:
import swig_decoders import paddlespeech_ctcdecoders
except ImportError: except ImportError:
try: try:
package_name = 'paddlespeech_ctcdecoders' package_name = 'paddlespeech_ctcdecoders'

@ -164,24 +164,18 @@ class DeepSpeech2Model(nn.Layer):
return loss return loss
@paddle.no_grad() @paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method, def decode(self, audio, audio_len):
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8 # decoders only accept string encoded in utf-8
self.decoder.init_decode(
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
# Make sure the decoder has been initialized
eouts, eouts_len = self.encoder(audio, audio_len) eouts, eouts_len = self.encoder(audio, audio_len)
probs = self.decoder.softmax(eouts) probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs( batch_size = probs.shape[0]
probs.numpy(), eouts_len, vocab_list, decoding_method, self.decoder.reset_decoder(batch_size=batch_size)
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, self.decoder.next(probs, eouts_len)
cutoff_top_n, num_processes) trans_best, trans_beam = self.decoder.decode()
return trans_best
@classmethod @classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path): def from_pretrained(cls, dataloader, config, checkpoint_path):

@ -16,7 +16,7 @@ from .deepspeech2 import DeepSpeech2ModelOnline
from paddlespeech.s2t.utils import dynamic_pip_install from paddlespeech.s2t.utils import dynamic_pip_install
try: try:
import swig_decoders import paddlespeech_ctcdecoders
except ImportError: except ImportError:
try: try:
package_name = 'paddlespeech_ctcdecoders' package_name = 'paddlespeech_ctcdecoders'

@ -293,25 +293,17 @@ class DeepSpeech2ModelOnline(nn.Layer):
return loss return loss
@paddle.no_grad() @paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method, def decode(self, audio, audio_len):
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
# init once
# decoders only accept string encoded in utf-8 # decoders only accept string encoded in utf-8
self.decoder.init_decode( # Make sure the decoder has been initialized
beam_alpha=beam_alpha,
beam_beta=beam_beta,
lang_model_path=lang_model_path,
vocab_list=vocab_list,
decoding_method=decoding_method)
eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder( eouts, eouts_len, final_state_h_box, final_state_c_box = self.encoder(
audio, audio_len, None, None) audio, audio_len, None, None)
probs = self.decoder.softmax(eouts) probs = self.decoder.softmax(eouts)
return self.decoder.decode_probs( batch_size = probs.shape[0]
probs.numpy(), eouts_len, vocab_list, decoding_method, self.decoder.reset_decoder(batch_size=batch_size)
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, self.decoder.next(probs, eouts_len)
cutoff_top_n, num_processes) trans_best, trans_beam = self.decoder.decode()
return trans_best
@classmethod @classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path): def from_pretrained(cls, dataloader, config, checkpoint_path):

@ -32,7 +32,7 @@ from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.models.asr_interface import ASRInterface from paddlespeech.s2t.models.asr_interface import ASRInterface
from paddlespeech.s2t.modules.cmvn import GlobalCMVN from paddlespeech.s2t.modules.cmvn import GlobalCMVN
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoderBase
from paddlespeech.s2t.modules.decoder import TransformerDecoder from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder
@ -63,7 +63,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
vocab_size: int, vocab_size: int,
encoder: TransformerEncoder, encoder: TransformerEncoder,
decoder: TransformerDecoder, decoder: TransformerDecoder,
ctc: CTCDecoder, ctc: CTCDecoderBase,
ctc_weight: float=0.5, ctc_weight: float=0.5,
ignore_id: int=IGNORE_ID, ignore_id: int=IGNORE_ID,
lsm_weight: float=0.0, lsm_weight: float=0.0,
@ -840,7 +840,7 @@ class U2Model(U2DecodeModel):
model_conf = configs.get('model_conf', dict()) model_conf = configs.get('model_conf', dict())
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0) dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
grad_norm_type = model_conf.get('ctc_grad_norm_type', None) grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
ctc = CTCDecoder( ctc = CTCDecoderBase(
odim=vocab_size, odim=vocab_size,
enc_n_units=encoder.output_size(), enc_n_units=encoder.output_size(),
blank_id=0, blank_id=0,

@ -28,7 +28,7 @@ from paddle import nn
from paddlespeech.s2t.frontend.utility import IGNORE_ID from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.modules.cmvn import GlobalCMVN from paddlespeech.s2t.modules.cmvn import GlobalCMVN
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoderBase
from paddlespeech.s2t.modules.decoder import TransformerDecoder from paddlespeech.s2t.modules.decoder import TransformerDecoder
from paddlespeech.s2t.modules.encoder import ConformerEncoder from paddlespeech.s2t.modules.encoder import ConformerEncoder
from paddlespeech.s2t.modules.encoder import TransformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder
@ -56,7 +56,7 @@ class U2STBaseModel(nn.Layer):
encoder: TransformerEncoder, encoder: TransformerEncoder,
st_decoder: TransformerDecoder, st_decoder: TransformerDecoder,
decoder: TransformerDecoder=None, decoder: TransformerDecoder=None,
ctc: CTCDecoder=None, ctc: CTCDecoderBase=None,
ctc_weight: float=0.0, ctc_weight: float=0.0,
asr_weight: float=0.0, asr_weight: float=0.0,
ignore_id: int=IGNORE_ID, ignore_id: int=IGNORE_ID,
@ -313,8 +313,7 @@ class U2STBaseModel(nn.Layer):
cache = [ cache = [
paddle.ones( paddle.ones(
(len(hyps), i - 1, hyp_cache.shape[-1]), (len(hyps), i - 1, hyp_cache.shape[-1]),
dtype=paddle.float32) dtype=paddle.float32) for hyp_cache in hyps[0]["cache"]
for hyp_cache in hyps[0]["cache"]
] ]
for j, hyp in enumerate(hyps): for j, hyp in enumerate(hyps):
ys[j, :] = paddle.to_tensor(hyp["yseq"]) ys[j, :] = paddle.to_tensor(hyp["yseq"])
@ -596,7 +595,7 @@ class U2STModel(U2STBaseModel):
model_conf = configs['model_conf'] model_conf = configs['model_conf']
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0) dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
grad_norm_type = model_conf.get('ctc_grad_norm_type', None) grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
ctc = CTCDecoder( ctc = CTCDecoderBase(
odim=vocab_size, odim=vocab_size,
enc_n_units=encoder.output_size(), enc_n_units=encoder.output_size(),
blank_id=0, blank_id=0,

@ -25,17 +25,19 @@ from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
try: try:
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import ctc_beam_search_decoding_batch # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_greedy_decoder # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import ctc_greedy_decoding # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import Scorer # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import CTCBeamSearchDecoder # noqa: F401
except ImportError: except ImportError:
try: try:
from paddlespeech.s2t.utils import dynamic_pip_install from paddlespeech.s2t.utils import dynamic_pip_install
package_name = 'paddlespeech_ctcdecoders' package_name = 'paddlespeech_ctcdecoders'
dynamic_pip_install.install(package_name) dynamic_pip_install.install(package_name)
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_beam_search_decoder_batch # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import ctc_beam_search_decoding_batch # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import ctc_greedy_decoder # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import ctc_greedy_decoding # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder.swig_wrapper import Scorer # noqa: F401 from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401
from paddlespeech.s2t.decoders.ctcdecoder import CTCBeamSearchDecoder # noqa: F401
except Exception as e: except Exception as e:
logger.info("paddlespeech_ctcdecoders not installed!") logger.info("paddlespeech_ctcdecoders not installed!")
@ -139,9 +141,11 @@ class CTCDecoder(CTCDecoderBase):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# CTCDecoder LM Score handle # CTCDecoder LM Score handle
self._ext_scorer = None self._ext_scorer = None
self.beam_search_decoder = None
def _decode_batch_greedy(self, probs_split, vocab_list): def _decode_batch_greedy_offline(self, probs_split, vocab_list):
"""Decode by best path for a batch of probs matrix input. """This function will be deprecated in future.
Decode by best path for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists :param probs_split: List of 2-D probability matrix, and each consists
of prob vectors for one speech utterancce. of prob vectors for one speech utterancce.
:param probs_split: List of matrix :param probs_split: List of matrix
@ -152,7 +156,7 @@ class CTCDecoder(CTCDecoderBase):
""" """
results = [] results = []
for i, probs in enumerate(probs_split): for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder( output_transcription = ctc_greedy_decoding(
probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id) probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id)
results.append(output_transcription) results.append(output_transcription)
return results return results
@ -194,10 +198,12 @@ class CTCDecoder(CTCDecoderBase):
logger.info("no language model provided, " logger.info("no language model provided, "
"decoding by pure beam search without scorer.") "decoding by pure beam search without scorer.")
def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta, def _decode_batch_beam_search_offline(
beam_size, cutoff_prob, cutoff_top_n, self, probs_split, beam_alpha, beam_beta, beam_size, cutoff_prob,
vocab_list, num_processes): cutoff_top_n, vocab_list, num_processes):
"""Decode by beam search for a batch of probs matrix input. """
This function will be deprecated in future.
Decode by beam search for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrix, and each consists :param probs_split: List of 2-D probability matrix, and each consists
of prob vectors for one speech utterancce. of prob vectors for one speech utterancce.
:param probs_split: List of matrix :param probs_split: List of matrix
@ -226,7 +232,7 @@ class CTCDecoder(CTCDecoderBase):
# beam search decode # beam search decode
num_processes = min(num_processes, len(probs_split)) num_processes = min(num_processes, len(probs_split))
beam_search_results = ctc_beam_search_decoder_batch( beam_search_results = ctc_beam_search_decoding_batch(
probs_split=probs_split, probs_split=probs_split,
vocabulary=vocab_list, vocabulary=vocab_list,
beam_size=beam_size, beam_size=beam_size,
@ -239,30 +245,69 @@ class CTCDecoder(CTCDecoderBase):
results = [result[0][1] for result in beam_search_results] results = [result[0][1] for result in beam_search_results]
return results return results
def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, def init_decoder(self, batch_size, vocab_list, decoding_method,
decoding_method): lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_prob, cutoff_top_n, num_processes):
"""
init ctc decoders
Args:
batch_size(int): Batch size for input data
vocab_list (list): List of tokens in the vocabulary, for decoding
decoding_method (str): ctc_beam_search
lang_model_path (str): language model path
beam_alpha (float): beam_alpha
beam_beta (float): beam_beta
beam_size (int): beam_size
cutoff_prob (float): cutoff probability in beam search
cutoff_top_n (int): cutoff_top_n
num_processes (int): num_processes
Raises:
ValueError: when decoding_method not support.
Returns:
CTCBeamSearchDecoder
"""
self.batch_size = batch_size
self.vocab_list = vocab_list
self.decoding_method = decoding_method
self.beam_size = beam_size
self.cutoff_prob = cutoff_prob
self.cutoff_top_n = cutoff_top_n
self.num_processes = num_processes
if decoding_method == "ctc_beam_search": if decoding_method == "ctc_beam_search":
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
vocab_list) vocab_list)
if self.beam_search_decoder is None:
self.beam_search_decoder = self.get_decoder(
vocab_list, batch_size, beam_alpha, beam_beta, beam_size,
num_processes, cutoff_prob, cutoff_top_n)
return self.beam_search_decoder
elif decoding_method == "ctc_greedy":
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
vocab_list)
else:
raise ValueError(f"Not support: {decoding_method}")
def decode_probs(self, probs, logits_lens, vocab_list, decoding_method, def decode_probs_offline(self, probs, logits_lens, vocab_list,
lang_model_path, beam_alpha, beam_beta, beam_size, decoding_method, lang_model_path, beam_alpha,
cutoff_prob, cutoff_top_n, num_processes): beam_beta, beam_size, cutoff_prob, cutoff_top_n,
"""ctc decoding with probs. num_processes):
"""
This function will be deprecated in future.
ctc decoding with probs.
Args: Args:
probs (Tensor): activation after softmax probs (Tensor): activation after softmax
logits_lens (Tensor): audio output lens logits_lens (Tensor): audio output lens
vocab_list ([type]): [description] vocab_list (list): List of tokens in the vocabulary, for decoding
decoding_method ([type]): [description] decoding_method (str): ctc_beam_search
lang_model_path ([type]): [description] lang_model_path (str): language model path
beam_alpha ([type]): [description] beam_alpha (float): beam_alpha
beam_beta ([type]): [description] beam_beta (float): beam_beta
beam_size ([type]): [description] beam_size (int): beam_size
cutoff_prob ([type]): [description] cutoff_prob (float): cutoff probability in beam search
cutoff_top_n ([type]): [description] cutoff_top_n (int): cutoff_top_n
num_processes ([type]): [description] num_processes (int): num_processes
Raises: Raises:
ValueError: when decoding_method not support. ValueError: when decoding_method not support.
@ -270,13 +315,14 @@ class CTCDecoder(CTCDecoderBase):
Returns: Returns:
List[str]: transcripts. List[str]: transcripts.
""" """
logger.warn(
"This function will be deprecated in future: decode_probs_offline")
probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)] probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)]
if decoding_method == "ctc_greedy": if decoding_method == "ctc_greedy":
result_transcripts = self._decode_batch_greedy( result_transcripts = self._decode_batch_greedy_offline(
probs_split=probs_split, vocab_list=vocab_list) probs_split=probs_split, vocab_list=vocab_list)
elif decoding_method == "ctc_beam_search": elif decoding_method == "ctc_beam_search":
result_transcripts = self._decode_batch_beam_search( result_transcripts = self._decode_batch_beam_search_offline(
probs_split=probs_split, probs_split=probs_split,
beam_alpha=beam_alpha, beam_alpha=beam_alpha,
beam_beta=beam_beta, beam_beta=beam_beta,
@ -288,3 +334,136 @@ class CTCDecoder(CTCDecoderBase):
else: else:
raise ValueError(f"Not support: {decoding_method}") raise ValueError(f"Not support: {decoding_method}")
return result_transcripts return result_transcripts
def get_decoder(self, vocab_list, batch_size, beam_alpha, beam_beta,
                beam_size, num_processes, cutoff_prob, cutoff_top_n):
    """
    Build a new CTC beam search decoder.

    Args:
        vocab_list (list): List of tokens in the vocabulary, for decoding.
        batch_size (int): Batch size for input data
        beam_alpha (float): beam_alpha, forwarded to the external scorer
        beam_beta (float): beam_beta, forwarded to the external scorer
        beam_size (int): beam_size
        num_processes (int): num_processes, clamped to batch_size
        cutoff_prob (float): cutoff probability in beam search
        cutoff_top_n (int): cutoff_top_n

    Raises:
        ValueError: when self.decoding_method is not "ctc_beam_search".

    Returns:
        CTCBeamSearchDecoder
    """
    # Clamp the worker count to the batch size.
    num_processes = min(num_processes, batch_size)
    if self._ext_scorer is not None:
        self._ext_scorer.reset_params(beam_alpha, beam_beta)
    if self.decoding_method != "ctc_beam_search":
        # `decoding_method` is not a local name here; read it from self so
        # the documented ValueError is raised instead of a NameError.
        raise ValueError(f"Not support: {self.decoding_method}")
    return CTCBeamSearchDecoder(
        vocab_list, batch_size, beam_size, num_processes, cutoff_prob,
        cutoff_top_n, self._ext_scorer, self.blank_id)
def next(self, probs, logits_lens):
    """
    Input probs into ctc decoder

    Args:
        probs: probs for a batch of data; indexed as
            ``probs[i, :length, :]`` and converted with ``.tolist()``
        logits_lens: logits lens for a batch of data; must support
            ``logits_lens > 0`` followed by ``.tolist()``, and iteration

    Raises:
        Exception: when the ctc decoder is not initialized
        ValueError: when decoding_method not support.
    """
    if self.beam_search_decoder is None:
        raise Exception(
            "You need to initialize the beam_search_decoder firstly")
    if self.decoding_method != "ctc_beam_search":
        # Read the method from self; the bare name is undefined here.
        raise ValueError(f"Not support: {self.decoding_method}")

    # Keep the flags as real booleans while they drive Python logic; the
    # "true"/"false" strings are only what the decoder's next() receives.
    has_frames = [flag is True for flag in (logits_lens > 0).tolist()]
    probs_split = [
        # A sample without frames contributes an empty sequence, which is
        # what slicing with a zero length produced before.
        probs[i, :length, :].tolist() if has_frames[i] else []
        for i, length in enumerate(logits_lens)
    ]
    has_value = ["true" if flag else "false" for flag in has_frames]
    self.beam_search_decoder.next(probs_split, has_value)
    return
def decode(self):
    """
    Get the decoding result

    Raises:
        Exception: when the ctc decoder is not initialized
        ValueError: when decoding_method not support.

    Returns:
        results_best (list(str)): The best result for a batch of data
        results_beam (list(list(str))): The beam search result for a batch of data
    """
    if self.beam_search_decoder is None:
        raise Exception(
            "You need to initialize the beam_search_decoder firstly")
    if self.decoding_method != "ctc_beam_search":
        # Read the method from self; the bare name is undefined here.
        raise ValueError(f"Not support: {self.decoding_method}")

    # Each beam entry is indexed as (res[0], res[1]); only res[1], the
    # transcription, is returned to the caller.
    batch_beam_results = self.beam_search_decoder.decode()
    results_beam = [[res[1] for res in beam_results]
                    for beam_results in batch_beam_results]
    # The first entry of every beam is reported as the best result.
    results_best = [beam[0] for beam in results_beam]
    return results_best, results_beam
def reset_decoder(self,
                  batch_size=-1,
                  beam_size=-1,
                  num_processes=-1,
                  cutoff_prob=-1.0,
                  cutoff_top_n=-1):
    """
    Reset the decoder state

    Any argument left at a non-positive value keeps the setting currently
    stored on this object.

    Args:
        batch_size(int): Batch size for input data
        beam_size (int): beam_size
        num_processes (int): num_processes
        cutoff_prob (float): cutoff probability in beam search
        cutoff_top_n (int): cutoff_top_n

    Raises:
        Exception: when the ctc decoder is not initialized
    """
    if batch_size > 0:
        self.batch_size = batch_size
    if beam_size > 0:
        self.beam_size = beam_size
    if num_processes > 0:
        self.num_processes = num_processes
    if cutoff_prob > 0:
        self.cutoff_prob = cutoff_prob
    if cutoff_top_n > 0:
        self.cutoff_top_n = cutoff_top_n

    if self.beam_search_decoder is None:
        raise Exception(
            "You need to initialize the beam_search_decoder firstly")
    self.beam_search_decoder.reset_state(
        self.batch_size, self.beam_size, self.num_processes,
        self.cutoff_prob, self.cutoff_top_n)
def del_decoder(self):
    """
    Drop the current beam search decoder, if there is one, and leave the
    attribute set to None.
    """
    if self.beam_search_decoder is None:
        return
    del self.beam_search_decoder
    self.beam_search_decoder = None

@ -29,7 +29,8 @@
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::pair<double, std::string>> ctc_beam_search_decoding(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
size_t beam_size, size_t beam_size,
@ -46,6 +47,8 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
"The shape of probs_seq does not match with " "The shape of probs_seq does not match with "
"the shape of the vocabulary"); "the shape of the vocabulary");
} }
// assign space id // assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE); auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin(); int space_id = it - vocabulary.begin();
@ -206,7 +209,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoding_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split, const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
size_t beam_size, size_t beam_size,
@ -224,7 +227,7 @@ ctc_beam_search_decoder_batch(
// enqueue the tasks of decoding // enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res; std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < batch_size; ++i) { for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder, res.emplace_back(pool.enqueue(ctc_beam_search_decoding,
probs_split[i], probs_split[i],
vocabulary, vocabulary,
beam_size, beam_size,
@ -241,3 +244,364 @@ ctc_beam_search_decoder_batch(
} }
return batch_results; return batch_results;
} }
// Prepare `root` for word-level language model scoring: when a scorer is
// present and is not character based, give the root a copy of the scorer's
// dictionary FST together with a matcher built over that copy.
void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer) {
    // Nothing to attach without a scorer, or with a character-based one.
    if (ext_scorer == nullptr || ext_scorer->is_character_based()) {
        return;
    }
    auto *scorer_dict =
        static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
    fst::StdVectorFst *dict_copy = scorer_dict->Copy(true);
    root->set_dictionary(dict_copy);
    root->set_matcher(
        std::make_shared<FSTMATCH>(*dict_copy, fst::MATCH_INPUT));
}
/**
 * Advance a CTC prefix beam search over one chunk of frames.
 *
 * Unlike ctc_beam_search_decoding, the trie root and the list of live
 * prefixes are supplied by the caller and are updated in place, so the
 * function can be called repeatedly on consecutive chunks.
 *
 * Parameters:
 *     root: root of the prefix trie.
 *     prefixes: current prefixes; rebuilt from `root` after every time step.
 *     probs_seq: one probability vector per time step; each must have
 *                vocabulary.size() entries.
 *     vocabulary: token list.
 *     beam_size: number of prefixes expanded and kept per step.
 *     cutoff_prob, cutoff_top_n: forwarded to get_pruned_log_probs.
 *     ext_scorer: optional external scorer, may be nullptr.
 *     blank_id: index of the blank token.
 */
void ctc_beam_search_decode_chunk(
PathTrie *root,
std::vector<PathTrie *> &prefixes,
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id) {
// dimension check
size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(),
// vocabulary.size() + 1,
vocabulary.size(),
"The shape of probs_seq does not match with "
"the shape of the vocabulary");
}
// assign space id
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// root and prefixes are provided by the caller, nothing to initialise here
// prefix search over time
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
auto &prob = probs_seq[time_step];
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
// with a scorer, derive a pruning threshold from the lowest-ranked of the
// first num_prefixes prefixes (after sorting them with prefix_compare)
if (ext_scorer != nullptr) {
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(),
prefixes.begin() + num_prefixes,
prefix_compare);
min_cutoff = prefixes[num_prefixes - 1]->score +
std::log(prob[blank_id]) -
std::max(0.0, ext_scorer->beta);
full_beam = (num_prefixes == beam_size);
}
std::vector<std::pair<size_t, float>> log_prob_idx =
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
// loop over chars
for (size_t index = 0; index < log_prob_idx.size(); index++) {
auto c = log_prob_idx[index].first;
auto log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
// stop expanding with this char once the beam is full and the
// candidate falls below the threshold
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur = log_sum_exp(
prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur =
log_sum_exp(prefix->log_prob_nb_cur,
log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
// a repeated char extends the prefix only from the blank-ending
// probability; a different char extends from the total score
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) *
ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
prefixes.clear();
// update log probs
root->iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(),
prefixes.begin() + beam_size,
prefixes.end(),
prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
}
} // end of loop over time
return;
}
/**
 * Produce the decoding result for the current prefixes.
 *
 * With a word-based scorer, the score of each of the first beam_size
 * prefixes that does not end with a space is first increased by the scorer
 * contribution for its last word; after the result has been extracted the
 * same quantity is subtracted again (see the "pay back" step below).
 *
 * Parameters:
 *     prefixes: current prefixes; the first min(size, beam_size) entries are
 *               sorted in place with prefix_compare.
 *     vocabulary: token list.
 *     beam_size: number of prefixes considered.
 *     ext_scorer: optional external scorer, may be nullptr.
 * Return:
 *     The value returned by get_beam_search_result(prefixes, vocabulary,
 *     beam_size).
 */
std::vector<std::pair<double, std::string>> get_decode_result(
std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size,
Scorer *ext_scorer) {
auto it = std::find(vocabulary.begin(), vocabulary.end(), kSPACE);
int space_id = it - vocabulary.begin();
// if no space in vocabulary
if ((size_t)space_id >= vocabulary.size()) {
space_id = -2;
}
// score the last word of each prefix that doesn't end with space
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score += score;
}
}
}
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
// compute approximate ctc score as the return score, without affecting the
// return order of decoding result. To delete when decoder gets stable.
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score;
if (ext_scorer != nullptr) {
std::vector<int> output;
prefixes[i]->get_path_vec(output);
auto prefix_length = output.size();
auto words = ext_scorer->split_labels(output);
// remove word insert
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
// remove language model weight:
approx_ctc -=
(ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
}
prefixes[i]->approx_ctc = approx_ctc;
}
std::vector<std::pair<double, std::string>> res =
get_beam_search_result(prefixes, vocabulary, beam_size);
// pay back the last word of each prefix that doesn't end with space (for
// decoding by chunk): this mirrors the scoring loop above and subtracts
// what it added
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
auto prefix = prefixes[i];
if (!prefix->is_empty() && prefix->character != space_id) {
float score = 0.0;
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
score =
ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
score += ext_scorer->beta;
prefix->score -= score;
}
}
}
return res;
}
void free_storage(std::unique_ptr<CtcBeamSearchDecoderStorage> &storage) {
storage = nullptr;
}
// Nothing to release by hand: the per-sample storages are held through
// std::unique_ptr in decoder_storage_vector.
CtcBeamSearchDecoderBatch::~CtcBeamSearchDecoderBatch() {}
/**
 * Create a batch decoder with one fresh storage (trie root + prefixes) per
 * sample in the batch.
 *
 * Parameters:
 *     vocabulary: token list, copied into the decoder.
 *     batch_size: number of per-sample storages to create.
 *     beam_size: beam width, must be greater than 0.
 *     num_processes: worker count, must be greater than 0.
 *     cutoff_prob, cutoff_top_n: pruning settings stored for later calls.
 *     ext_scorer: optional external scorer, may be nullptr; only the
 *                 pointer is stored.
 *     blank_id: index of the blank token.
 */
CtcBeamSearchDecoderBatch::CtcBeamSearchDecoderBatch(
    const std::vector<std::string> &vocabulary,
    size_t batch_size,
    size_t beam_size,
    size_t num_processes,
    double cutoff_prob,
    size_t cutoff_top_n,
    Scorer *ext_scorer,
    size_t blank_id)
    : vocabulary(vocabulary),
      batch_size(batch_size),
      beam_size(beam_size),
      num_processes(num_processes),
      cutoff_prob(cutoff_prob),
      cutoff_top_n(cutoff_top_n),
      ext_scorer(ext_scorer),
      blank_id(blank_id) {
    VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!");
    // The check rejects 0, so the message says "greater than 0" rather than
    // "nonnegative".
    VALID_CHECK_GT(
        this->num_processes, 0, "num_processes must be greater than 0!");

    this->decoder_storage_vector.reserve(batch_size);
    for (size_t i = 0; i < batch_size; i++) {
        this->decoder_storage_vector.push_back(
            std::unique_ptr<CtcBeamSearchDecoderStorage>(
                new CtcBeamSearchDecoderStorage()));
        ctc_beam_search_decode_chunk_begin(
            this->decoder_storage_vector[i]->root, ext_scorer);
    }
}
/**
 * Feed one chunk of probabilities to every sample that has data.
 *
 * Input
 *     probs_split: one probability sequence per sample in the batch.
 *     has_value: one flag per sample; a sample is decoded only when its
 *                flag equals the string "true", otherwise its entry in
 *                probs_split is ignored.
 */
void CtcBeamSearchDecoderBatch::next(
    const std::vector<std::vector<std::vector<double>>> &probs_split,
    const std::vector<std::string> &has_value) {
    VALID_CHECK_GT(num_processes, 0, "num_processes must be greater than 0!");
    // number of samples
    size_t probs_num = probs_split.size();
    VALID_CHECK_EQ(this->batch_size,
                   probs_num,
                   "The batch size of the current input data should be same "
                   "with the input data before");
    // has_value is indexed with the same index as probs_split below.
    VALID_CHECK_EQ(this->batch_size,
                   has_value.size(),
                   "has_value should hold one flag per sample in the batch");

    size_t num_has_value = 0;
    for (size_t i = 0; i < has_value.size(); ++i) {
        if (has_value[i] == "true") {
            num_has_value += 1;
        }
    }
    // Nothing to decode: avoid building a thread pool with zero workers.
    if (num_has_value == 0) {
        return;
    }

    // thread pool
    ThreadPool pool(std::min(num_processes, num_has_value));
    // enqueue the tasks of decoding
    std::vector<std::future<void>> res;
    res.reserve(num_has_value);
    for (size_t i = 0; i < batch_size; ++i) {
        if (has_value[i] == "true") {
            res.emplace_back(pool.enqueue(
                ctc_beam_search_decode_chunk,
                std::ref(this->decoder_storage_vector[i]->root),
                std::ref(this->decoder_storage_vector[i]->prefixes),
                probs_split[i],
                this->vocabulary,
                this->beam_size,
                this->cutoff_prob,
                this->cutoff_top_n,
                this->ext_scorer,
                this->blank_id));
        }
    }
    // Only samples flagged "true" produced a future, so wait on res.size()
    // entries; indexing up to batch_size would run past the end of `res`.
    for (size_t i = 0; i < res.size(); ++i) {
        res[i].get();
    }
}
/**
 * Collect the current decoding result of every sample in the batch.
 *
 * One get_decode_result task per sample is run on a thread pool of
 * num_processes workers; results are gathered in sample order.
 *
 * Return
 * batch_result: shape[B, beam_size,(-approx_ctc score, string)]
 * (NOTE(review): the exact content of each pair is decided by
 * get_beam_search_result, which is not visible here)
 */
std::vector<std::vector<std::pair<double, std::string>>>
CtcBeamSearchDecoderBatch::decode() {
VALID_CHECK_GT(
this->num_processes, 0, "num_processes must be nonnegative!");
// thread pool
ThreadPool pool(this->num_processes);
// enqueue one result-extraction task per sample
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < this->batch_size; ++i) {
res.emplace_back(
pool.enqueue(get_decode_result,
std::ref(this->decoder_storage_vector[i]->prefixes),
this->vocabulary,
this->beam_size,
this->ext_scorer));
}
// get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (size_t i = 0; i < this->batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}
/**
 * reset the state of ctcBeamSearchDecoderBatch
 *
 * Stores the new settings, frees every existing per-sample storage and
 * creates batch_size fresh storages whose roots are prepared with
 * ctc_beam_search_decode_chunk_begin.
 */
void CtcBeamSearchDecoderBatch::reset_state(size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n) {
this->batch_size = batch_size;
this->beam_size = beam_size;
this->num_processes = num_processes;
this->cutoff_prob = cutoff_prob;
this->cutoff_top_n = cutoff_top_n;
VALID_CHECK_GT(this->beam_size, 0, "beam_size must be greater than 0!");
VALID_CHECK_GT(
this->num_processes, 0, "num_processes must be nonnegative!");
// thread pool
ThreadPool pool(this->num_processes);
// free the old storages on the pool, one task per storage
std::vector<std::future<void>> res;
size_t storage_size = decoder_storage_vector.size();
for (size_t i = 0; i < storage_size; i++) {
res.emplace_back(pool.enqueue(
free_storage, std::ref(this->decoder_storage_vector[i])));
}
// wait until every storage has been freed
for (size_t i = 0; i < storage_size; ++i) {
res[i].get();
}
// swap with an empty vector to drop the now-empty pointers
std::vector<std::unique_ptr<CtcBeamSearchDecoderStorage>>().swap(
decoder_storage_vector);
// rebuild one storage per sample for the new batch size
for (size_t i = 0; i < this->batch_size; i++) {
this->decoder_storage_vector.push_back(
std::unique_ptr<CtcBeamSearchDecoderStorage>(
new CtcBeamSearchDecoderStorage()));
ctc_beam_search_decode_chunk_begin(
this->decoder_storage_vector[i]->root, this->ext_scorer);
}
}

@ -37,7 +37,7 @@
* A vector that each element is a pair of score and decoding result, * A vector that each element is a pair of score and decoding result,
* in desending order. * in desending order.
*/ */
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder( std::vector<std::pair<double, std::string>> ctc_beam_search_decoding(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
size_t beam_size, size_t beam_size,
@ -46,6 +46,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
Scorer *ext_scorer = nullptr, Scorer *ext_scorer = nullptr,
size_t blank_id = 0); size_t blank_id = 0);
/* CTC Beam Search Decoder for batch data /* CTC Beam Search Decoder for batch data
* Parameters: * Parameters:
@ -64,7 +65,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
* result for one audio sample. * result for one audio sample.
*/ */
std::vector<std::vector<std::pair<double, std::string>>> std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch( ctc_beam_search_decoding_batch(
const std::vector<std::vector<std::vector<double>>> &probs_split, const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
size_t beam_size, size_t beam_size,
@ -74,4 +75,101 @@ ctc_beam_search_decoder_batch(
Scorer *ext_scorer = nullptr, Scorer *ext_scorer = nullptr,
size_t blank_id = 0); size_t blank_id = 0);
/**
 * Store the root and prefixes for decoder
 *
 * The object owns `root` (allocated in the constructor, deleted in the
 * destructor). `prefixes` only holds pointers to trie nodes and starts out
 * containing the root.
 */
class CtcBeamSearchDecoderStorage {
public:
    PathTrie *root = nullptr;
    std::vector<PathTrie *> prefixes;

    CtcBeamSearchDecoderStorage() {
        // init prefixes' root
        this->root = new PathTrie();
        this->root->log_prob_b_prev = 0.0;
        // The score of root is in log scale. Since the prob=1.0, the prob
        // score in log scale is 0.0
        this->root->score = root->log_prob_b_prev;
        this->prefixes.push_back(root);
    }

    // The destructor deletes `root`, so a copy would share the pointer and
    // delete it a second time. Forbid copying.
    CtcBeamSearchDecoderStorage(const CtcBeamSearchDecoderStorage &) = delete;
    CtcBeamSearchDecoderStorage &operator=(
        const CtcBeamSearchDecoderStorage &) = delete;

    ~CtcBeamSearchDecoderStorage() {
        delete root;
        root = nullptr;
    }
};
/**
 * The ctc beam search decoder, support batchsize >= 1
 *
 * Keeps one CtcBeamSearchDecoderStorage per sample so that probabilities
 * can be fed in several calls to next() before reading results with
 * decode().
 */
class CtcBeamSearchDecoderBatch {
public:
// Create the decoder and one storage per sample in the batch.
CtcBeamSearchDecoderBatch(const std::vector<std::string> &vocabulary,
size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id);
~CtcBeamSearchDecoderBatch();
// Feed probabilities for the batch; samples whose has_value entry is the
// string "true" are decoded.
void next(const std::vector<std::vector<std::vector<double>>> &probs_split,
const std::vector<std::string> &has_value);
// Return, for every sample, its list of (score, transcription) pairs.
std::vector<std::vector<std::pair<double, std::string>>> decode();
// Replace the settings and start over with fresh per-sample storages.
void reset_state(size_t batch_size,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n);
private:
std::vector<std::string> vocabulary;
size_t batch_size;
size_t beam_size;
size_t num_processes;
double cutoff_prob;
size_t cutoff_top_n;
// not owned: only the pointer passed to the constructor is stored
Scorer *ext_scorer;
size_t blank_id;
// one storage (trie root + prefixes) per sample in the batch
std::vector<std::unique_ptr<CtcBeamSearchDecoderStorage>>
decoder_storage_vector;
};
/**
 * function for chunk decoding
 *
 * Runs the beam search over the frames in probs_seq, updating the trie
 * under `root` and the `prefixes` list in place.
 */
void ctc_beam_search_decode_chunk(
PathTrie *root,
std::vector<PathTrie *> &prefixes,
const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer,
size_t blank_id);
/**
 * get the decoding result for the current prefixes as a list of
 * (score, transcription) pairs
 */
std::vector<std::pair<double, std::string>> get_decode_result(
std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary,
size_t beam_size,
Scorer *ext_scorer);
/**
 * free the CtcBeamSearchDecoderStorage
 */
void free_storage(std::unique_ptr<CtcBeamSearchDecoderStorage> &storage);
/**
 * initialize the root: for a non character based scorer, attach a copy of
 * the scorer dictionary and a matcher to it
 */
void ctc_beam_search_decode_chunk_begin(PathTrie *root, Scorer *ext_scorer);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_

@ -15,7 +15,7 @@
#include "ctc_greedy_decoder.h" #include "ctc_greedy_decoder.h"
#include "decoder_utils.h" #include "decoder_utils.h"
std::string ctc_greedy_decoder( std::string ctc_greedy_decoding(
const std::vector<std::vector<double>> &probs_seq, const std::vector<std::vector<double>> &probs_seq,
const std::vector<std::string> &vocabulary, const std::vector<std::string> &vocabulary,
size_t blank_id) { size_t blank_id) {

@ -27,7 +27,7 @@
* Return: * Return:
* The decoding result in string * The decoding result in string
*/ */
std::string ctc_greedy_decoder( std::string ctc_greedy_decoding(
const std::vector<std::vector<double>>& probs_seq, const std::vector<std::vector<double>>& probs_seq,
const std::vector<std::string>& vocabulary, const std::vector<std::string>& vocabulary,
size_t blank_id); size_t blank_id);

@ -1,4 +1,4 @@
%module swig_decoders %module paddlespeech_ctcdecoders
%{ %{
#include "scorer.h" #include "scorer.h"
#include "ctc_greedy_decoder.h" #include "ctc_greedy_decoder.h"

@ -44,6 +44,7 @@ PathTrie::PathTrie() {
PathTrie::~PathTrie() { PathTrie::~PathTrie() {
for (auto child : children_) { for (auto child : children_) {
delete child.second; delete child.second;
child.second = nullptr;
} }
} }
@ -131,26 +132,26 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
void PathTrie::remove() { void PathTrie::remove() {
exists_ = false; exists_ = false;
if (children_.size() == 0) { if (children_.size() == 0) {
auto child = parent->children_.begin(); if (parent != nullptr) {
for (child = parent->children_.begin(); auto child = parent->children_.begin();
child != parent->children_.end(); for (child = parent->children_.begin();
++child) { child != parent->children_.end();
if (child->first == character) { ++child) {
parent->children_.erase(child); if (child->first == character) {
break; parent->children_.erase(child);
break;
}
}
if (parent->children_.size() == 0 && !parent->exists_) {
parent->remove();
} }
} }
if (parent->children_.size() == 0 && !parent->exists_) {
parent->remove();
}
delete this; delete this;
} }
} }
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
dictionary_ = dictionary; dictionary_ = dictionary;
dictionary_state_ = dictionary->Start(); dictionary_state_ = dictionary->Start();

@ -1,4 +1,5 @@
// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the "COPYING.LESSER.3"); // Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the
// "COPYING.LESSER.3");
#include "scorer.h" #include "scorer.h"

@ -1,4 +1,5 @@
// Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the "COPYING.LESSER.3"); // Licensed under GNU Lesser General Public License v3 (LGPLv3) (LGPL-3) (the
// "COPYING.LESSER.3");
#ifndef SCORER_H_ #ifndef SCORER_H_
#define SCORER_H_ #define SCORER_H_

@ -112,7 +112,7 @@ os.system('swig -python -c++ ./decoders.i')
decoders_module = [ decoders_module = [
Extension( Extension(
name='_swig_decoders', name='_paddlespeech_ctcdecoders',
sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'), sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
language='c++', language='c++',
include_dirs=[ include_dirs=[
@ -134,4 +134,4 @@ setup(
url="https://github.com/PaddlePaddle/PaddleSpeech", url="https://github.com/PaddlePaddle/PaddleSpeech",
license='Apache 2.0, GNU Lesser General Public License v3 (LGPLv3) (LGPL-3)', license='Apache 2.0, GNU Lesser General Public License v3 (LGPLv3) (LGPL-3)',
ext_modules=decoders_module, ext_modules=decoders_module,
py_modules=['swig_decoders']) py_modules=['paddlespeech_ctcdecoders'])

Loading…
Cancel
Save