@@ -17,41 +17,38 @@ std::string ctc_greedy_decoder(
     const std::vector<std::vector<double>> &probs_seq,
     const std::vector<std::string> &vocabulary) {
   // dimension check
-  int num_time_steps = probs_seq.size();
-  for (int i = 0; i < num_time_steps; i++) {
-    if (probs_seq[i].size() != vocabulary.size() + 1) {
-      std::cout << "The shape of probs_seq does not match"
-                << " with the shape of the vocabulary!" << std::endl;
-      exit(1);
-    }
+  size_t num_time_steps = probs_seq.size();
+  for (size_t i = 0; i < num_time_steps; i++) {
+    VALID_CHECK_EQ(probs_seq[i].size(),
+                   vocabulary.size() + 1,
+                   "The shape of probs_seq does not match with "
+                   "the shape of the vocabulary");
   }
-  int blank_id = vocabulary.size();
+  size_t blank_id = vocabulary.size();
-  std::vector<int> max_idx_vec;
-  double max_prob = 0.0;
-  int max_idx = 0;
-  for (int i = 0; i < num_time_steps; i++) {
-    for (int j = 0; j < probs_seq[i].size(); j++) {
+  std::vector<size_t> max_idx_vec;
+  for (size_t i = 0; i < num_time_steps; i++) {
+    double max_prob = 0.0;
+    size_t max_idx = 0;
+    for (size_t j = 0; j < probs_seq[i].size(); j++) {
       if (max_prob < probs_seq[i][j]) {
         max_idx = j;
         max_prob = probs_seq[i][j];
       }
     }
     max_idx_vec.push_back(max_idx);
-    max_prob = 0.0;
-    max_idx = 0;
   }
-  std::vector<int> idx_vec;
-  for (int i = 0; i < max_idx_vec.size(); i++) {
+  std::vector<size_t> idx_vec;
+  for (size_t i = 0; i < max_idx_vec.size(); i++) {
     if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
       idx_vec.push_back(max_idx_vec[i]);
     }
   }
   std::string best_path_result;
-  for (int i = 0; i < idx_vec.size(); i++) {
+  for (size_t i = 0; i < idx_vec.size(); i++) {
     if (idx_vec[i] != blank_id) {
       best_path_result += vocabulary[idx_vec[i]];
     }
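For reference, the greedy path above is the standard CTC best-path rule: take the argmax label at every time step, collapse consecutive repeats, then drop the blank. A minimal standalone sketch of that collapse step, using a made-up label sequence and a toy vocabulary (not taken from the patch):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical per-frame argmax labels; blank_id = vocabulary.size() = 3.
  std::vector<size_t> max_idx_vec = {3, 1, 1, 3, 3, 2, 2, 0};
  std::vector<std::string> vocabulary = {"a", "b", "c"};
  size_t blank_id = vocabulary.size();

  std::string best_path_result;
  for (size_t i = 0; i < max_idx_vec.size(); i++) {
    // keep a label only where it differs from the previous frame's label
    if (i == 0 || max_idx_vec[i] != max_idx_vec[i - 1]) {
      // and only if it is not the blank
      if (max_idx_vec[i] != blank_id) {
        best_path_result += vocabulary[max_idx_vec[i]];
      }
    }
  }
  std::cout << best_path_result << std::endl;  // prints "bca"
  return 0;
}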
@@ -61,29 +58,24 @@ std::string ctc_greedy_decoder(
 std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     const std::vector<std::vector<double>> &probs_seq,
-    int beam_size,
+    const size_t beam_size,
     std::vector<std::string> vocabulary,
-    int blank_id,
-    double cutoff_prob,
-    int cutoff_top_n,
-    Scorer *extscorer) {
+    const double cutoff_prob,
+    const size_t cutoff_top_n,
+    Scorer *ext_scorer) {
   // dimension check
   size_t num_time_steps = probs_seq.size();
-  for (int i = 0; i < num_time_steps; i++) {
-    if (probs_seq[i].size() != vocabulary.size() + 1) {
-      std::cout << "The shape of probs_seq does not match"
-                << " with the shape of the vocabulary!" << std::endl;
-      exit(1);
-    }
+  for (size_t i = 0; i < num_time_steps; i++) {
+    VALID_CHECK_EQ(probs_seq[i].size(),
+                   vocabulary.size() + 1,
+                   "The shape of probs_seq does not match with "
+                   "the shape of the vocabulary");
   }
-  // blank_id check
-  if (blank_id > vocabulary.size()) {
-    std::cout << "Invalid blank_id!" << std::endl;
-    exit(1);
-  }
+  // assign blank id
+  size_t blank_id = vocabulary.size();
-  // assign space ID
+  // assign space id
   std::vector<std::string>::iterator it =
       std::find(vocabulary.begin(), vocabulary.end(), " ");
   int space_id = it - vocabulary.begin();
@@ -98,16 +90,16 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
   std::vector<PathTrie *> prefixes;
   prefixes.push_back(&root);
-  if (extscorer != nullptr) {
-    if (extscorer->is_char_map_empty()) {
-      extscorer->set_char_map(vocabulary);
+  if (ext_scorer != nullptr) {
+    if (ext_scorer->is_char_map_empty()) {
+      ext_scorer->set_char_map(vocabulary);
     }
-    if (!extscorer->is_character_based()) {
-      if (extscorer->dictionary == nullptr) {
+    if (!ext_scorer->is_character_based()) {
+      if (ext_scorer->dictionary == nullptr) {
         // fill dictionary for fst with space
-        extscorer->fill_dictionary(true);
+        ext_scorer->fill_dictionary(true);
       }
-      auto fst_dict = static_cast<fst::StdVectorFst *>(extscorer->dictionary);
+      auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
       fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
       root.set_dictionary(dict_ptr);
       auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
@@ -116,33 +108,33 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
   }
   // prefix search over time
-  for (int time_step = 0; time_step < num_time_steps; time_step++) {
+  for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
     std::vector<double> prob = probs_seq[time_step];
     std::vector<std::pair<int, double>> prob_idx;
-    for (int i = 0; i < prob.size(); i++) {
+    for (size_t i = 0; i < prob.size(); i++) {
       prob_idx.push_back(std::pair<int, double>(i, prob[i]));
     }
     float min_cutoff = -NUM_FLT_INF;
     bool full_beam = false;
-    if (extscorer != nullptr) {
-      int num_prefixes = std::min((int)prefixes.size(), beam_size);
+    if (ext_scorer != nullptr) {
+      size_t num_prefixes = std::min(prefixes.size(), beam_size);
       std::sort(
           prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
       min_cutoff = prefixes[num_prefixes - 1]->score + log(prob[blank_id]) -
-                   std::max(0.0, extscorer->beta);
+                   std::max(0.0, ext_scorer->beta);
       full_beam = (num_prefixes == beam_size);
     }
     // pruning of vocabulary
-    int cutoff_len = prob.size();
+    size_t cutoff_len = prob.size();
     if (cutoff_prob < 1.0 || cutoff_top_n < prob.size()) {
       std::sort(
          prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
      if (cutoff_prob < 1.0) {
        double cum_prob = 0.0;
        cutoff_len = 0;
-        for (int i = 0; i < prob_idx.size(); i++) {
+        for (size_t i = 0; i < prob_idx.size(); i++) {
          cum_prob += prob_idx[i].second;
          cutoff_len += 1;
          if (cum_prob >= cutoff_prob) break;
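The block above prunes the per-frame distribution before any prefix is expanded: labels are sorted by probability and only the smallest leading set whose cumulative mass reaches cutoff_prob is kept. A toy sketch of that cumulative cutoff with made-up numbers (the lambda stands in for pair_comp_second_rev, whose definition is not part of this diff):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Made-up per-frame probabilities, already paired with their label indices.
  std::vector<std::pair<int, double>> prob_idx = {
      {0, 0.05}, {1, 0.50}, {2, 0.30}, {3, 0.15}};
  const double cutoff_prob = 0.9;

  // Sort by probability, highest first (in place of pair_comp_second_rev).
  std::sort(prob_idx.begin(),
            prob_idx.end(),
            [](const std::pair<int, double> &a,
               const std::pair<int, double> &b) { return a.second > b.second; });

  // Keep the smallest prefix whose cumulative probability reaches cutoff_prob.
  double cum_prob = 0.0;
  size_t cutoff_len = 0;
  for (size_t i = 0; i < prob_idx.size(); i++) {
    cum_prob += prob_idx[i].second;
    cutoff_len += 1;
    if (cum_prob >= cutoff_prob) break;
  }
  std::cout << "kept " << cutoff_len << " of " << prob_idx.size() << " labels"
            << std::endl;  // kept 3 of 4 labels
  return 0;
}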
@@ -152,18 +144,18 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
       prob_idx = std::vector<std::pair<int, double>>(
           prob_idx.begin(), prob_idx.begin() + cutoff_len);
     }
-    std::vector<std::pair<int, float>> log_prob_idx;
-    for (int i = 0; i < cutoff_len; i++) {
+    std::vector<std::pair<size_t, float>> log_prob_idx;
+    for (size_t i = 0; i < cutoff_len; i++) {
       log_prob_idx.push_back(std::pair<int, float>(
           prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
     }
     // loop over chars
-    for (int index = 0; index < log_prob_idx.size(); index++) {
+    for (size_t index = 0; index < log_prob_idx.size(); index++) {
       auto c = log_prob_idx[index].first;
       float log_prob_c = log_prob_idx[index].second;
-      for (int i = 0; i < prefixes.size() && i < beam_size; i++) {
+      for (size_t i = 0; i < prefixes.size() && i < beam_size; i++) {
         auto prefix = prefixes[i];
         if (full_beam && log_prob_c + prefix->score < min_cutoff) {
@@ -194,12 +186,12 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
           }
           // language model scoring
-          if (extscorer != nullptr &&
-              (c == space_id || extscorer->is_character_based())) {
+          if (ext_scorer != nullptr &&
+              (c == space_id || ext_scorer->is_character_based())) {
             PathTrie *prefix_toscore = nullptr;
             // skip scoring the space
-            if (extscorer->is_character_based()) {
+            if (ext_scorer->is_character_based()) {
               prefix_toscore = prefix_new;
             } else {
               prefix_toscore = prefix;
@@ -207,11 +199,11 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
             double score = 0.0;
             std::vector<std::string> ngram;
-            ngram = extscorer->make_ngram(prefix_toscore);
-            score = extscorer->get_log_cond_prob(ngram) * extscorer->alpha;
+            ngram = ext_scorer->make_ngram(prefix_toscore);
+            score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
             log_p += score;
-            log_p += extscorer->beta;
+            log_p += ext_scorer->beta;
           }
           prefix_new->log_prob_nb_cur =
               log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
@@ -240,15 +232,15 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
   for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
     double approx_ctc = prefixes[i]->score;
-    if (extscorer != nullptr) {
+    if (ext_scorer != nullptr) {
       std::vector<int> output;
       prefixes[i]->get_path_vec(output);
       size_t prefix_length = output.size();
-      auto words = extscorer->split_labels(output);
+      auto words = ext_scorer->split_labels(output);
       // remove word insert
-      approx_ctc = approx_ctc - prefix_length * extscorer->beta;
+      approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
       // remove language model weight:
-      approx_ctc -= (extscorer->get_sent_log_prob(words)) * extscorer->alpha;
+      approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
     }
     prefixes[i]->approx_ctc = approx_ctc;
@@ -269,7 +261,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     space_prefixes[i]->get_path_vec(output);
     // convert index to string
     std::string output_str;
-    for (int j = 0; j < output.size(); j++) {
+    for (size_t j = 0; j < output.size(); j++) {
       output_str += vocabulary[output[j]];
     }
     std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
@@ -283,49 +275,45 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
 std::vector<std::vector<std::pair<double, std::string>>>
 ctc_beam_search_decoder_batch(
     const std::vector<std::vector<std::vector<double>>> &probs_split,
-    int beam_size,
+    const size_t beam_size,
     const std::vector<std::string> &vocabulary,
-    int blank_id,
-    int num_processes,
-    double cutoff_prob,
-    int cutoff_top_n,
-    Scorer *extscorer) {
-  if (num_processes <= 0) {
-    std::cout << "num_processes must be nonnegative!" << std::endl;
-    exit(1);
-  }
+    const size_t num_processes,
+    const double cutoff_prob,
+    const size_t cutoff_top_n,
+    Scorer *ext_scorer) {
+  VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
   // thread pool
   ThreadPool pool(num_processes);
   // number of samples
-  int batch_size = probs_split.size();
+  size_t batch_size = probs_split.size();
   // scorer filling up
-  if (extscorer != nullptr) {
-    if (extscorer->is_char_map_empty()) {
-      extscorer->set_char_map(vocabulary);
+  if (ext_scorer != nullptr) {
+    if (ext_scorer->is_char_map_empty()) {
+      ext_scorer->set_char_map(vocabulary);
     }
-    if (!extscorer->is_character_based() && extscorer->dictionary == nullptr) {
+    if (!ext_scorer->is_character_based() &&
+        ext_scorer->dictionary == nullptr) {
       // init dictionary
-      extscorer->fill_dictionary(true);
+      ext_scorer->fill_dictionary(true);
     }
   }
   // enqueue the tasks of decoding
   std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
-  for (int i = 0; i < batch_size; i++) {
+  for (size_t i = 0; i < batch_size; i++) {
     res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
                                   probs_split[i],
                                   beam_size,
                                   vocabulary,
-                                  blank_id,
                                   cutoff_prob,
                                   cutoff_top_n,
-                                  extscorer));
+                                  ext_scorer));
   }
   // get decoding results
   std::vector<std::vector<std::pair<double, std::string>>> batch_results;
-  for (int i = 0; i < batch_size; i++) {
+  for (size_t i = 0; i < batch_size; i++) {
     batch_results.emplace_back(res[i].get());
   }
   return batch_results;
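For orientation, a rough caller-side sketch of the refactored batch entry point, matching the new signature above (blank_id is now derived internally from the vocabulary). The header names, the placeholder probabilities, and the nullptr scorer are assumptions for illustration; pass a real Scorer to enable language-model rescoring:

// Usage sketch only; the headers and data below are assumed, not from the patch.
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

#include "ctc_beam_search_decoder.h"  // assumed to declare the functions above
#include "scorer.h"                   // assumed to declare Scorer

int main() {
  std::vector<std::string> vocab = {"a", "b", "c", " "};
  // one utterance, two time steps, vocab.size() + 1 probabilities per step
  std::vector<std::vector<std::vector<double>>> probs_split = {
      {{0.1, 0.6, 0.1, 0.1, 0.1}, {0.1, 0.1, 0.6, 0.1, 0.1}}};

  const size_t beam_size = 20;
  const size_t num_processes = 4;
  const double cutoff_prob = 0.99;
  const size_t cutoff_top_n = 40;
  Scorer *ext_scorer = nullptr;  // no language model in this sketch

  auto batch_results = ctc_beam_search_decoder_batch(probs_split,
                                                     beam_size,
                                                     vocab,
                                                     num_processes,
                                                     cutoff_prob,
                                                     cutoff_top_n,
                                                     ext_scorer);
  // batch_results[0] holds the (score, transcript) candidates for the utterance
  return 0;
}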