|
|
@ -12,14 +12,14 @@
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "recognizer/u2_recognizer.h"
|
|
|
|
#include "common/base/thread_pool.h"
|
|
|
|
#include "common/base/thread_pool.h"
|
|
|
|
#include "common/utils/strings.h"
|
|
|
|
|
|
|
|
#include "common/utils/file_utils.h"
|
|
|
|
#include "common/utils/file_utils.h"
|
|
|
|
|
|
|
|
#include "common/utils/strings.h"
|
|
|
|
#include "decoder/param.h"
|
|
|
|
#include "decoder/param.h"
|
|
|
|
#include "frontend/wave-reader.h"
|
|
|
|
#include "frontend/wave-reader.h"
|
|
|
|
#include "nnet/u2_nnet.h"
|
|
|
|
|
|
|
|
#include "kaldi/util/table-types.h"
|
|
|
|
#include "kaldi/util/table-types.h"
|
|
|
|
#include "recognizer/u2_recognizer.h"
|
|
|
|
#include "nnet/u2_nnet.h"
|
|
|
|
|
|
|
|
|
|
|
|
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
|
|
|
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
|
|
|
|
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
|
|
|
DEFINE_string(result_wspecifier, "", "test result wspecifier");
|
|
|
@ -31,10 +31,10 @@ using std::string;
|
|
|
|
using std::vector;
|
|
|
|
using std::vector;
|
|
|
|
|
|
|
|
|
|
|
|
void SplitUtt(string wavlist_file,
|
|
|
|
void SplitUtt(string wavlist_file,
|
|
|
|
vector<vector<string>>* uttlists,
|
|
|
|
vector<vector<string>>* uttlists,
|
|
|
|
vector<vector<string>>* wavlists,
|
|
|
|
vector<vector<string>>* wavlists,
|
|
|
|
int njob) {
|
|
|
|
int njob) {
|
|
|
|
vector<string> wavlist;
|
|
|
|
vector<string> wavlist;
|
|
|
|
wavlists->resize(njob);
|
|
|
|
wavlists->resize(njob);
|
|
|
|
uttlists->resize(njob);
|
|
|
|
uttlists->resize(njob);
|
|
|
|
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
|
|
|
|
ppspeech::ReadFileToVector(wavlist_file, &wavlist);
|
|
|
@ -43,13 +43,13 @@ void SplitUtt(string wavlist_file,
|
|
|
|
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
|
|
|
|
vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
|
|
|
|
LOG(INFO) << utt_wav[0];
|
|
|
|
LOG(INFO) << utt_wav[0];
|
|
|
|
CHECK_EQ(utt_wav.size(), size_t(2));
|
|
|
|
CHECK_EQ(utt_wav.size(), size_t(2));
|
|
|
|
uttlists->at(idx % njob).push_back(utt_wav[0]);
|
|
|
|
uttlists->at(idx % njob).push_back(utt_wav[0]);
|
|
|
|
wavlists->at(idx % njob).push_back(utt_wav[1]);
|
|
|
|
wavlists->at(idx % njob).push_back(utt_wav[1]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void recognizer_func(const ppspeech::U2RecognizerResource& resource,
|
|
|
|
void recognizer_func(const ppspeech::U2RecognizerResource& resource,
|
|
|
|
std::shared_ptr<ppspeech::NnetBase> nnet,
|
|
|
|
std::shared_ptr<ppspeech::NnetBase> nnet,
|
|
|
|
std::vector<string> wavlist,
|
|
|
|
std::vector<string> wavlist,
|
|
|
|
std::vector<string> uttlist,
|
|
|
|
std::vector<string> uttlist,
|
|
|
|
std::vector<string>* results) {
|
|
|
|
std::vector<string>* results) {
|
|
|
@ -60,8 +60,8 @@ void recognizer_func(const ppspeech::U2RecognizerResource& resource,
|
|
|
|
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
|
|
|
|
int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
|
|
|
|
if (wavlist.empty()) return;
|
|
|
|
if (wavlist.empty()) return;
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ppspeech::U2Recognizer> recognizer_ptr(
|
|
|
|
std::shared_ptr<ppspeech::U2Recognizer> recognizer_ptr =
|
|
|
|
new ppspeech::U2Recognizer(resource, nnet));
|
|
|
|
std::make_shared<ppspeech::U2Recognizer>(resource, nnet);
|
|
|
|
|
|
|
|
|
|
|
|
results->reserve(wavlist.size());
|
|
|
|
results->reserve(wavlist.size());
|
|
|
|
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
|
|
|
|
for (size_t idx = 0; idx < wavlist.size(); ++idx) {
|
|
|
@ -118,22 +118,22 @@ void recognizer_func(const ppspeech::U2RecognizerResource& resource,
|
|
|
|
result = " ";
|
|
|
|
result = " ";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
tot_decode_time += local_timer.Elapsed();
|
|
|
|
tot_decode_time += local_timer.Elapsed();
|
|
|
|
LOG(INFO) << utt << " " << result;
|
|
|
|
LOG(INFO) << utt << " " << result;
|
|
|
|
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
|
|
|
|
LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
|
|
|
|
<< " cost: " << local_timer.Elapsed();
|
|
|
|
<< " cost: " << local_timer.Elapsed();
|
|
|
|
|
|
|
|
|
|
|
|
results->push_back(result);
|
|
|
|
results->push_back(result);
|
|
|
|
++num_done;
|
|
|
|
++num_done;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
recognizer_ptr->WaitFinished();
|
|
|
|
recognizer_ptr->WaitFinished();
|
|
|
|
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
|
|
|
|
LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
|
|
|
|
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
|
|
|
|
LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
|
|
|
|
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
|
|
|
|
LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
|
|
|
|
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
|
|
|
|
LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
|
|
|
|
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
|
|
|
|
LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
int main(int argc, char* argv[]) {
|
|
|
|
int main(int argc, char* argv[]) {
|
|
|
|
gflags::SetUsageMessage("Usage:");
|
|
|
|
gflags::SetUsageMessage("Usage:");
|
|
|
|
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
|
|
|
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
|
|
@ -157,14 +157,19 @@ int main(int argc, char* argv[]) {
|
|
|
|
vector<vector<string>> uttlist;
|
|
|
|
vector<vector<string>> uttlist;
|
|
|
|
vector<vector<string>> resultlist(njob);
|
|
|
|
vector<vector<string>> resultlist(njob);
|
|
|
|
vector<std::future<void>> futurelist;
|
|
|
|
vector<std::future<void>> futurelist;
|
|
|
|
std::thread threads[njob];
|
|
|
|
std::shared_ptr<ppspeech::U2Nnet> nnet(
|
|
|
|
std::shared_ptr<ppspeech::U2Nnet> nnet(new ppspeech::U2Nnet(resource.model_opts));
|
|
|
|
new ppspeech::U2Nnet(resource.model_opts));
|
|
|
|
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
|
|
|
|
SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
|
|
|
|
for (size_t i = 0; i < njob; ++i) {
|
|
|
|
for (size_t i = 0; i < njob; ++i) {
|
|
|
|
std::future<void> f = threadpool.enqueue(recognizer_func, resource, nnet->Clone(), wavlist[i], uttlist[i], &resultlist[i]);
|
|
|
|
std::future<void> f = threadpool.enqueue(recognizer_func,
|
|
|
|
|
|
|
|
resource,
|
|
|
|
|
|
|
|
nnet->Clone(),
|
|
|
|
|
|
|
|
wavlist[i],
|
|
|
|
|
|
|
|
uttlist[i],
|
|
|
|
|
|
|
|
&resultlist[i]);
|
|
|
|
futurelist.push_back(std::move(f));
|
|
|
|
futurelist.push_back(std::move(f));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < njob; ++i) {
|
|
|
|
for (size_t i = 0; i < njob; ++i) {
|
|
|
|
futurelist[i].get();
|
|
|
|
futurelist[i].get();
|
|
|
|
}
|
|
|
|
}
|
|
|
|