# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper for various CTC decoders in SWIG."""
import swig_decoders


class Scorer(swig_decoders.Scorer):
    """Wrapper for Scorer.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: str
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    """

    def __init__(self, alpha, beta, model_path, vocabulary):
        swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
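
# Example: constructing a Scorer (a minimal sketch; the language-model path
# and vocabulary below are hypothetical, not part of this module):
#
#     vocab = [' ', 'a', 'b', 'c']
#     scorer = Scorer(alpha=2.5, beta=0.3,
#                     model_path='/path/to/lm.binary', vocabulary=vocab)
#     # A Scorer instance can then be supplied as `ext_scoring_func`
#     # to the beam search decoders below.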


def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
    """Wrapper for ctc best path decoder in swig.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of the CTC blank token in the probability
                     distributions.
    :type blank_id: int
    :return: Decoding result string.
    :rtype: str
    """
    result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
                                              blank_id)
    return result
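
# Example: greedy decoding (a minimal sketch; `probs` is a hypothetical numpy
# array of shape (num_time_steps, num_classes), each row a normalized
# distribution over the vocabulary plus the blank token):
#
#     import numpy as np
#     vocab = ['a', 'b']
#     probs = np.array([[0.1, 0.8, 0.1],
#                       [0.7, 0.1, 0.2]])  # column 2 is the blank
#     text = ctc_greedy_decoder(probs, vocab, blank_id=2)  # -> 'ba'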


def ctc_beam_search_decoder(probs_seq,
                            vocabulary,
                            beam_size,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None,
                            blank_id=0):
    """Wrapper for the CTC Beam Search Decoder.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param cutoff_prob: Cutoff probability in pruning; default 1.0 means
                        no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities in the
                         vocabulary will be used in beam search. Default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for partially decoded
                             sentences, e.g. word count or language model.
    :type ext_scoring_func: callable
    :param blank_id: ID of the CTC blank token in the probability
                     distributions.
    :type blank_id: int
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    beam_results = swig_decoders.ctc_beam_search_decoder(
        probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
        ext_scoring_func, blank_id)
    beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
    return beam_results
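
# Example: beam search decoding (a minimal sketch reusing the hypothetical
# `probs` and `vocab` from the greedy example; `scorer` is optional):
#
#     results = ctc_beam_search_decoder(probs, vocab, beam_size=20,
#                                       ext_scoring_func=scorer, blank_id=2)
#     best_log_prob, best_text = results[0]  # results are sorted best-first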


def ctc_beam_search_decoder_batch(probs_split,
                                  vocabulary,
                                  beam_size,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None,
                                  blank_id=0):
    """Wrapper for the batched CTC beam search decoder.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in vocabulary pruning; default 1.0
                        means no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities in the
                         vocabulary will be used in beam search. Default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for partially decoded
                             sentences, e.g. word count or language model.
    :type ext_scoring_func: callable
    :param blank_id: ID of the CTC blank token in the probability
                     distributions.
    :type blank_id: int
    :return: List of lists of tuples of log probability and sentence as
             decoding results, in descending order of the probability.
    :rtype: list
    """
    probs_split = [probs_seq.tolist() for probs_seq in probs_split]

    batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
        probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
        cutoff_top_n, ext_scoring_func, blank_id)
    batch_beam_results = [[(res[0], res[1]) for res in beam_results]
                          for beam_results in batch_beam_results]
    return batch_beam_results
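
# Example: batched beam search decoding (a minimal sketch; `probs_a` and
# `probs_b` are hypothetical numpy arrays shaped like `probs` above, one per
# utterance):
#
#     batch_results = ctc_beam_search_decoder_batch(
#         [probs_a, probs_b], vocab, beam_size=20, num_processes=2,
#         blank_id=2)
#     best_texts = [hyps[0][1] for hyps in batch_results]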