append some changes

pull/2/head
Yibing Liu 7 years ago
commit d77eec2bfe

@ -18,7 +18,7 @@ std::string ctc_greedy_decoder(
const std::vector<std::string> &vocabulary) { const std::vector<std::string> &vocabulary) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; i++) { for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(), VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1, vocabulary.size() + 1,
"The shape of probs_seq does not match with " "The shape of probs_seq does not match with "
@ -28,7 +28,7 @@ std::string ctc_greedy_decoder(
size_t blank_id = vocabulary.size(); size_t blank_id = vocabulary.size();
std::vector<size_t> max_idx_vec; std::vector<size_t> max_idx_vec;
for (size_t i = 0; i < num_time_steps; i++) { for (size_t i = 0; i < num_time_steps; ++i) {
double max_prob = 0.0; double max_prob = 0.0;
size_t max_idx = 0; size_t max_idx = 0;
for (size_t j = 0; j < probs_seq[i].size(); j++) { for (size_t j = 0; j < probs_seq[i].size(); j++) {
@ -41,14 +41,14 @@ std::string ctc_greedy_decoder(
} }
std::vector<size_t> idx_vec; std::vector<size_t> idx_vec;
for (size_t i = 0; i < max_idx_vec.size(); i++) { for (size_t i = 0; i < max_idx_vec.size(); ++i) {
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) { if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
idx_vec.push_back(max_idx_vec[i]); idx_vec.push_back(max_idx_vec[i]);
} }
} }
std::string best_path_result; std::string best_path_result;
for (size_t i = 0; i < idx_vec.size(); i++) { for (size_t i = 0; i < idx_vec.size(); ++i) {
if (idx_vec[i] != blank_id) { if (idx_vec[i] != blank_id) {
best_path_result += vocabulary[idx_vec[i]]; best_path_result += vocabulary[idx_vec[i]];
} }
@ -65,7 +65,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
Scorer *ext_scorer) { Scorer *ext_scorer) {
// dimension check // dimension check
size_t num_time_steps = probs_seq.size(); size_t num_time_steps = probs_seq.size();
for (size_t i = 0; i < num_time_steps; i++) { for (size_t i = 0; i < num_time_steps; ++i) {
VALID_CHECK_EQ(probs_seq[i].size(), VALID_CHECK_EQ(probs_seq[i].size(),
vocabulary.size() + 1, vocabulary.size() + 1,
"The shape of probs_seq does not match with " "The shape of probs_seq does not match with "
@ -111,7 +111,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
for (size_t time_step = 0; time_step < num_time_steps; time_step++) { for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
std::vector<double> prob = probs_seq[time_step]; std::vector<double> prob = probs_seq[time_step];
std::vector<std::pair<int, double>> prob_idx; std::vector<std::pair<int, double>> prob_idx;
for (size_t i = 0; i < prob.size(); i++) { for (size_t i = 0; i < prob.size(); ++i) {
prob_idx.push_back(std::pair<int, double>(i, prob[i])); prob_idx.push_back(std::pair<int, double>(i, prob[i]));
} }
@ -134,7 +134,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
if (cutoff_prob < 1.0) { if (cutoff_prob < 1.0) {
double cum_prob = 0.0; double cum_prob = 0.0;
cutoff_len = 0; cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); i++) { for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second; cum_prob += prob_idx[i].second;
cutoff_len += 1; cutoff_len += 1;
if (cum_prob >= cutoff_prob) break; if (cum_prob >= cutoff_prob) break;
@ -145,7 +145,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
prob_idx.begin(), prob_idx.begin() + cutoff_len); prob_idx.begin(), prob_idx.begin() + cutoff_len);
} }
std::vector<std::pair<size_t, float>> log_prob_idx; std::vector<std::pair<size_t, float>> log_prob_idx;
for (size_t i = 0; i < cutoff_len; i++) { for (size_t i = 0; i < cutoff_len; ++i) {
log_prob_idx.push_back(std::pair<int, float>( log_prob_idx.push_back(std::pair<int, float>(
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN))); prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
} }
@ -155,7 +155,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
auto c = log_prob_idx[index].first; auto c = log_prob_idx[index].first;
float log_prob_c = log_prob_idx[index].second; float log_prob_c = log_prob_idx[index].second;
for (size_t i = 0; i < prefixes.size() && i < beam_size; i++) { for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i]; auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) { if (full_beam && log_prob_c + prefix->score < min_cutoff) {
@ -222,14 +222,14 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
prefixes.end(), prefixes.end(),
prefix_compare); prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); i++) { for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove(); prefixes[i]->remove();
} }
} }
} // end of loop over time } // end of loop over time
// compute approximate ctc score as the return score // compute approximate ctc score as the return score
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
double approx_ctc = prefixes[i]->score; double approx_ctc = prefixes[i]->score;
if (ext_scorer != nullptr) { if (ext_scorer != nullptr) {
@ -249,14 +249,14 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
// allow for the post processing // allow for the post processing
std::vector<PathTrie *> space_prefixes; std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) { if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) { for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
space_prefixes.push_back(prefixes[i]); space_prefixes.push_back(prefixes[i]);
} }
} }
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, std::string>> output_vecs; std::vector<std::pair<double, std::string>> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) { for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
std::vector<int> output; std::vector<int> output;
space_prefixes[i]->get_path_vec(output); space_prefixes[i]->get_path_vec(output);
// convert index to string // convert index to string
@ -301,7 +301,7 @@ ctc_beam_search_decoder_batch(
// enqueue the tasks of decoding // enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res; std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
for (size_t i = 0; i < batch_size; i++) { for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(ctc_beam_search_decoder, res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
probs_split[i], probs_split[i],
beam_size, beam_size,
@ -313,7 +313,7 @@ ctc_beam_search_decoder_batch(
// get decoding results // get decoding results
std::vector<std::vector<std::pair<double, std::string>>> batch_results; std::vector<std::vector<std::pair<double, std::string>>> batch_results;
for (size_t i = 0; i < batch_size; i++) { for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get()); batch_results.emplace_back(res[i].get());
} }
return batch_results; return batch_results;

@ -0,0 +1,19 @@
#!/usr/bin/env bash
# Download the Aishell model archive and unpack it.
#
# All paths are relative, so run this script from a directory in which
# ../../utils/utility.sh resolves. The archive is saved to, and extracted
# in, the current working directory.
#
# Exit status: 0 on success, 1 if the download or the extraction fails.

# Expected to provide the download helper called below.
source ../../utils/utility.sh

URL='http://cloud.dlnel.org/filepub/?uuid=6c83b9d8-3255-4adf-9726-0fe0be3d0274'
MD5=28521a58552885a81cf92a1e9b133a71
TARGET=./aishell_model.tar.gz

echo "Download Aishell model ..."
# Quote every argument: the URL contains '?', which an unquoted expansion
# would expose to pathname expansion.
if ! download "$URL" "$MD5" "$TARGET"; then
    echo "Fail to download Aishell model!"
    exit 1
fi

# Do not report success when the archive cannot be unpacked.
if ! tar -zxvf "$TARGET"; then
    echo "Fail to extract Aishell model!"
    exit 1
fi

exit 0

@ -2,9 +2,8 @@
source ../../utils/utility.sh source ../../utils/utility.sh
# TODO: add urls URL='http://cloud.dlnel.org/filepub/?uuid=17404caf-cf19-492f-9707-1fad07c19aae'
URL='to-be-added' MD5=ea5024a457a91179472f6dfee60e053d
MD5=5b4af224b26c1dc4dd972b7d32f2f52a
TARGET=./librispeech_model.tar.gz TARGET=./librispeech_model.tar.gz

@ -0,0 +1,18 @@
#!/usr/bin/env bash
# Download the zh_giga.no_cna_cmn.prune01244.klm language model.
#
# All paths are relative, so run this script from a directory in which
# ../../utils/utility.sh resolves. The file is saved to the current
# working directory.
#
# Exit status: 0 on success, 1 if the download fails.

# Expected to provide the download helper called below.
source ../../utils/utility.sh

URL='http://cloud.dlnel.org/filepub/?uuid=d21861e4-4ed6-45bb-ad8e-ae417a43195e'
MD5="29e02312deb2e59b3c8686c7966d4fe3"
TARGET=./zh_giga.no_cna_cmn.prune01244.klm

echo "Download language model ..."
# Quote every argument: the URL contains '?', which an unquoted expansion
# would expose to pathname expansion.
if ! download "$URL" "$MD5" "$TARGET"; then
    echo "Fail to download the language model!"
    exit 1
fi

exit 0

@ -2,4 +2,3 @@ scipy==0.13.1
resampy==0.1.5 resampy==0.1.5
SoundFile==0.9.0.post1 SoundFile==0.9.0.post1
python_speech_features python_speech_features
https://github.com/luotao1/kenlm/archive/master.zip

@ -11,10 +11,9 @@ download() {
fi fi
fi fi
wget -c $URL -P `dirname "$TARGET"` wget -c $URL -O "$TARGET"
md5_result=`md5sum $TARGET | awk -F[' '] '{print $1}'` md5_result=`md5sum $TARGET | awk -F[' '] '{print $1}'`
if [ ! $MD5 == $md5_result ]; then if [ ! $MD5 == $md5_result ]; then
echo "Fail to download the language model!"
return 1 return 1
fi fi
} }

Loading…
Cancel
Save