seprate recognizer; NnetBase as base class

pull/2524/head
Hui Zhang 2 years ago
parent fddcd36fa0
commit 99b3632d4d

@ -32,6 +32,12 @@ ${CMAKE_CURRENT_SOURCE_DIR}/decoder
) )
add_subdirectory(decoder) add_subdirectory(decoder)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/recognizer
)
add_subdirectory(recognizer)
include_directories( include_directories(
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/protocol ${CMAKE_CURRENT_SOURCE_DIR}/protocol

@ -1,28 +1,24 @@
project(decoder)
include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders}) include_directories(${CMAKE_CURRENT_SOURCE_DIR/ctc_decoders})
set(decoder_src ) set(srcs)
if (USING_DS2) if (USING_DS2)
list(APPEND decoder_src list(APPEND srcs
ctc_decoders/decoder_utils.cpp ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp ctc_decoders/scorer.cpp
ctc_beam_search_decoder.cc ctc_beam_search_decoder.cc
ctc_tlg_decoder.cc ctc_tlg_decoder.cc
recognizer.cc
) )
endif() endif()
if (USING_U2) if (USING_U2)
list(APPEND decoder_src list(APPEND srcs
ctc_prefix_beam_search_decoder.cc ctc_prefix_beam_search_decoder.cc
u2_recognizer.cc
) )
endif() endif()
add_library(decoder STATIC ${decoder_src}) add_library(decoder STATIC ${srcs})
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings) target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings)
# test # test
@ -30,7 +26,6 @@ if (USING_DS2)
set(BINS set(BINS
ctc_beam_search_decoder_main ctc_beam_search_decoder_main
nnet_logprob_decoder_main nnet_logprob_decoder_main
recognizer_main
ctc_tlg_decoder_main ctc_tlg_decoder_main
) )
@ -45,7 +40,6 @@ endif()
if (USING_U2) if (USING_U2)
set(TEST_BINS set(TEST_BINS
ctc_prefix_beam_search_decoder_main ctc_prefix_beam_search_decoder_main
u2_recognizer_main
) )
foreach(bin_name IN LISTS TEST_BINS) foreach(bin_name IN LISTS TEST_BINS)

@ -21,7 +21,7 @@ using kaldi::Matrix;
using kaldi::Vector; using kaldi::Vector;
using std::vector; using std::vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet, Decodable::Decodable(const std::shared_ptr<NnetBase>& nnet,
const std::shared_ptr<FrontendInterface>& frontend, const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale) kaldi::BaseFloat acoustic_scale)
: frontend_(frontend), : frontend_(frontend),

@ -24,7 +24,7 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface { class Decodable : public kaldi::DecodableInterface {
public: public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet, explicit Decodable(const std::shared_ptr<NnetBase>& nnet,
const std::shared_ptr<FrontendInterface>& frontend, const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale = 1.0); kaldi::BaseFloat acoustic_scale = 1.0);
@ -63,14 +63,14 @@ class Decodable : public kaldi::DecodableInterface {
int32 TokenId2NnetId(int32 token_id); int32 TokenId2NnetId(int32 token_id);
std::shared_ptr<NnetInterface> Nnet() { return nnet_; } std::shared_ptr<NnetBase> Nnet() { return nnet_; }
// for offline test // for offline test
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood); void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
private: private:
std::shared_ptr<FrontendInterface> frontend_; std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_; std::shared_ptr<NnetBase> nnet_;
// nnet outputs' cache // nnet outputs' cache
kaldi::Matrix<kaldi::BaseFloat> nnet_out_cache_; kaldi::Matrix<kaldi::BaseFloat> nnet_out_cache_;

@ -48,7 +48,7 @@ class Tensor {
std::vector<T> _data; std::vector<T> _data;
}; };
class PaddleNnet : public NnetInterface { class PaddleNnet : public NnetBase {
public: public:
PaddleNnet(const ModelOptions& opts); PaddleNnet(const ModelOptions& opts);

@ -11,8 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "base/basic_types.h" #include "base/basic_types.h"
@ -105,11 +103,15 @@ class NnetInterface {
// true, nnet output is logprob; otherwise is prob, // true, nnet output is logprob; otherwise is prob,
virtual bool IsLogProb() = 0; virtual bool IsLogProb() = 0;
int SubsamplingRate() const { return subsampling_rate_; }
// using to get encoder outs. e.g. seq2seq with Attention model. // using to get encoder outs. e.g. seq2seq with Attention model.
virtual void EncoderOuts( virtual void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0; std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0;
};
class NnetBase : public NnetInterface {
public:
int SubsamplingRate() const { return subsampling_rate_; }
protected: protected:
int subsampling_rate_{1}; int subsampling_rate_{1};

@ -193,7 +193,7 @@ U2Nnet::U2Nnet(const U2Nnet& other) {
// ignore inner states // ignore inner states
} }
std::shared_ptr<NnetInterface> U2Nnet::Copy() const { std::shared_ptr<NnetBase> U2Nnet::Copy() const {
auto asr_model = std::make_shared<U2Nnet>(*this); auto asr_model = std::make_shared<U2Nnet>(*this);
// reset inner state for new decoding // reset inner state for new decoding
asr_model->Reset(); asr_model->Reset();

@ -24,7 +24,7 @@
namespace ppspeech { namespace ppspeech {
class U2NnetBase : public NnetInterface { class U2NnetBase : public NnetBase {
public: public:
virtual int context() const { return right_context_ + 1; } virtual int context() const { return right_context_ + 1; }
virtual int right_context() const { return right_context_; } virtual int right_context() const { return right_context_; }
@ -41,7 +41,7 @@ class U2NnetBase : public NnetInterface {
// start: false, it is the start chunk of one sentence, else true // start: false, it is the start chunk of one sentence, else true
virtual int num_frames_for_chunk(bool start) const; virtual int num_frames_for_chunk(bool start) const;
virtual std::shared_ptr<NnetInterface> Copy() const = 0; virtual std::shared_ptr<NnetBase> Copy() const = 0;
virtual void ForwardEncoderChunk( virtual void ForwardEncoderChunk(
const std::vector<kaldi::BaseFloat>& chunk_feats, const std::vector<kaldi::BaseFloat>& chunk_feats,
@ -99,7 +99,7 @@ class U2Nnet : public U2NnetBase {
std::shared_ptr<paddle::jit::Layer> model() const { return model_; } std::shared_ptr<paddle::jit::Layer> model() const { return model_; }
std::shared_ptr<NnetInterface> Copy() const override; std::shared_ptr<NnetBase> Copy() const override;
void ForwardEncoderChunkImpl( void ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats, const std::vector<kaldi::BaseFloat>& chunk_feats,

@ -2,7 +2,7 @@ add_library(websocket STATIC
websocket_server.cc websocket_server.cc
websocket_client.cc websocket_client.cc
) )
target_link_libraries(websocket PUBLIC frontend decoder nnet) target_link_libraries(websocket PUBLIC frontend nnet decoder recognizer)
add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc) add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc)
target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)

@ -19,7 +19,7 @@
#include "boost/asio/ip/tcp.hpp" #include "boost/asio/ip/tcp.hpp"
#include "boost/beast/core.hpp" #include "boost/beast/core.hpp"
#include "boost/beast/websocket.hpp" #include "boost/beast/websocket.hpp"
#include "decoder/recognizer.h" #include "recognizer/recognizer.h"
#include "frontend/audio/feature_pipeline.h" #include "frontend/audio/feature_pipeline.h"
namespace beast = boost::beast; // from <boost/beast.hpp> namespace beast = boost::beast; // from <boost/beast.hpp>

@ -0,0 +1,45 @@
set(srcs)
if (USING_DS2)
list(APPEND srcs
recognizer.cc
)
endif()
if (USING_U2)
list(APPEND srcs
u2_recognizer.cc
)
endif()
add_library(recognizer STATIC ${srcs})
target_link_libraries(recognizer PUBLIC decoder)
# test
if (USING_DS2)
set(BINS recognizer_main)
foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
endforeach()
endif()
if (USING_U2)
set(TEST_BINS
u2_recognizer_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endforeach()
endif()

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "decoder/recognizer.h" #include "recognizer/recognizer.h"
namespace ppspeech { namespace ppspeech {

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "decoder/param.h" #include "decoder/param.h"
#include "decoder/recognizer.h" #include "recognizer/recognizer.h"
#include "kaldi/feat/wave-reader.h" #include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate"); DEFINE_int32(sample_rate, 16000, "sample rate");
ppspeech::RecognizerResource InitRecognizerResoure() {
ppspeech::RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts =
ppspeech::FeaturePipelineOptions::InitFromFlags();
resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
resource.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags();
return resource;
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:"); gflags::SetUsageMessage("Usage:");
@ -39,7 +30,7 @@ int main(int argc, char* argv[]) {
google::InstallFailureSignalHandler(); google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1; FLAGS_logtostderr = 1;
ppspeech::RecognizerResource resource = InitRecognizerResoure(); ppspeech::RecognizerResource resource = ppspeech::RecognizerResource::InitFromFlags();
ppspeech::Recognizer recognizer(resource); ppspeech::Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader( kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "decoder/u2_recognizer.h" #include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h" #include "nnet/u2_nnet.h"
@ -30,7 +30,7 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts)); feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<NnetInterface> nnet(new U2Nnet(resource.model_opts)); std::shared_ptr<NnetBase> nnet(new U2Nnet(resource.model_opts));
BaseFloat am_scale = resource.acoustic_scale; BaseFloat am_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale)); decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));

@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "decoder/param.h" #include "decoder/param.h"
#include "decoder/u2_recognizer.h" #include "recognizer/u2_recognizer.h"
#include "kaldi/feat/wave-reader.h" #include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
Loading…
Cancel
Save