You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
113 lines
4.4 KiB
113 lines
4.4 KiB
"""Wrapper for various CTC decoders in SWIG."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import swig_decoders
|
|
|
|
|
|
class Scorer(swig_decoders.Scorer):
    """Python-side wrapper around the SWIG ``Scorer`` class.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    """

    def __init__(self, alpha, beta, model_path):
        # Forward all arguments unchanged to the SWIG base class initializer.
        parent = swig_decoders.Scorer
        parent.__init__(self, alpha, beta, model_path)
|
|
|
|
|
|
def ctc_best_path_decoder(probs_seq, vocabulary):
    """Wrapper for ctc best path decoder in swig.

    :param probs_seq: 2-D probability distributions over each time step,
                      with each row holding normalized probabilities over
                      vocabulary and blank. May be a nested list or any
                      object exposing ``tolist()`` (e.g. a numpy array).
    :type probs_seq: 2-D list or array-like
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    # Array-like inputs are converted to nested lists; plain lists are passed
    # through unchanged instead of failing on a missing ``tolist`` attribute.
    if hasattr(probs_seq, "tolist"):
        probs_seq = probs_seq.tolist()
    return swig_decoders.ctc_best_path_decoder(probs_seq, vocabulary)
|
|
|
|
|
|
def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            blank_id,
                            cutoff_prob=1.0,
                            ext_scoring_func=None):
    """Wrapper for the CTC Beam Search Decoder.

    :param probs_seq: 2-D probability distributions over each time step,
                      with each row holding normalized probabilities over
                      vocabulary and blank. May be a nested list or any
                      object exposing ``tolist()`` (e.g. a numpy array).
    :type probs_seq: 2-D list or array-like
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # Array-like inputs are converted to nested lists; plain lists are passed
    # through unchanged instead of failing on a missing ``tolist`` attribute.
    if hasattr(probs_seq, "tolist"):
        probs_seq = probs_seq.tolist()
    return swig_decoders.ctc_beam_search_decoder(probs_seq, beam_size,
                                                 vocabulary, blank_id,
                                                 cutoff_prob, ext_scoring_func)
|
|
|
|
|
|
def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  blank_id,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  ext_scoring_func=None):
    """Wrapper for the batched CTC beam search decoder.

    :param probs_split: 3-D input with each element as an instance of the
                        2-D probabilities used by ctc_beam_search_decoder().
                        Each element may be a nested list or any object
                        exposing ``tolist()`` (e.g. a numpy array).
    :type probs_split: 3-D list or iterable of array-likes
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of blank.
    :type blank_id: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # Convert each utterance to a nested list when it offers ``tolist()``;
    # elements that are already plain lists are forwarded unchanged instead
    # of failing on a missing ``tolist`` attribute.
    probs_split = [
        probs_seq.tolist() if hasattr(probs_seq, "tolist") else probs_seq
        for probs_seq in probs_split
    ]

    return swig_decoders.ctc_beam_search_decoder_batch(
        probs_split, beam_size, vocabulary, blank_id, num_processes,
        cutoff_prob, ext_scoring_func)
|