separate recognizer; NnetBase as base class

pull/2524/head
Hui Zhang 2 years ago
parent fddcd36fa0
commit 99b3632d4d

@ -32,6 +32,12 @@ ${CMAKE_CURRENT_SOURCE_DIR}/decoder
)
add_subdirectory(decoder)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/recognizer
)
add_subdirectory(recognizer)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/protocol

@ -1,28 +1,24 @@
project(decoder)
# Add the ctc_decoders subdirectory of this directory to the include path.
# The closing brace must come right after the variable name; otherwise the
# whole "CMAKE_CURRENT_SOURCE_DIR/ctc_decoders" string is treated as a single
# (undefined) variable reference and expands to nothing.
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ctc_decoders)
set(decoder_src )
set(srcs)
if (USING_DS2)
list(APPEND decoder_src
list(APPEND srcs
ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_beam_search_decoder.cc
ctc_tlg_decoder.cc
recognizer.cc
)
endif()
if (USING_U2)
list(APPEND decoder_src
list(APPEND srcs
ctc_prefix_beam_search_decoder.cc
u2_recognizer.cc
)
endif()
add_library(decoder STATIC ${decoder_src})
add_library(decoder STATIC ${srcs})
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings)
# test
@ -30,7 +26,6 @@ if (USING_DS2)
set(BINS
ctc_beam_search_decoder_main
nnet_logprob_decoder_main
recognizer_main
ctc_tlg_decoder_main
)
@ -45,7 +40,6 @@ endif()
if (USING_U2)
set(TEST_BINS
ctc_prefix_beam_search_decoder_main
u2_recognizer_main
)
foreach(bin_name IN LISTS TEST_BINS)

@ -21,7 +21,7 @@ using kaldi::Matrix;
using kaldi::Vector;
using std::vector;
Decodable::Decodable(const std::shared_ptr<NnetInterface>& nnet,
Decodable::Decodable(const std::shared_ptr<NnetBase>& nnet,
const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale)
: frontend_(frontend),

@ -24,7 +24,7 @@ struct DecodableOpts;
class Decodable : public kaldi::DecodableInterface {
public:
explicit Decodable(const std::shared_ptr<NnetInterface>& nnet,
explicit Decodable(const std::shared_ptr<NnetBase>& nnet,
const std::shared_ptr<FrontendInterface>& frontend,
kaldi::BaseFloat acoustic_scale = 1.0);
@ -63,14 +63,14 @@ class Decodable : public kaldi::DecodableInterface {
int32 TokenId2NnetId(int32 token_id);
std::shared_ptr<NnetInterface> Nnet() { return nnet_; }
std::shared_ptr<NnetBase> Nnet() { return nnet_; }
// for offline test
void Acceptlikelihood(const kaldi::Matrix<kaldi::BaseFloat>& likelihood);
private:
std::shared_ptr<FrontendInterface> frontend_;
std::shared_ptr<NnetInterface> nnet_;
std::shared_ptr<NnetBase> nnet_;
// nnet outputs' cache
kaldi::Matrix<kaldi::BaseFloat> nnet_out_cache_;

@ -48,7 +48,7 @@ class Tensor {
std::vector<T> _data;
};
class PaddleNnet : public NnetInterface {
class PaddleNnet : public NnetBase {
public:
PaddleNnet(const ModelOptions& opts);

@ -11,8 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/basic_types.h"
@ -105,11 +103,15 @@ class NnetInterface {
// Returns true if the nnet output is log-probability; otherwise it is probability.
virtual bool IsLogProb() = 0;
int SubsamplingRate() const { return subsampling_rate_; }
// Used to get the encoder outputs, e.g. for a seq2seq model with attention.
virtual void EncoderOuts(
std::vector<kaldi::Vector<kaldi::BaseFloat>>* encoder_out) const = 0;
};
class NnetBase : public NnetInterface {
public:
int SubsamplingRate() const { return subsampling_rate_; }
protected:
int subsampling_rate_{1};

@ -193,7 +193,7 @@ U2Nnet::U2Nnet(const U2Nnet& other) {
// ignore inner states
}
std::shared_ptr<NnetInterface> U2Nnet::Copy() const {
std::shared_ptr<NnetBase> U2Nnet::Copy() const {
auto asr_model = std::make_shared<U2Nnet>(*this);
// reset inner state for new decoding
asr_model->Reset();

@ -24,7 +24,7 @@
namespace ppspeech {
class U2NnetBase : public NnetInterface {
class U2NnetBase : public NnetBase {
public:
virtual int context() const { return right_context_ + 1; }
virtual int right_context() const { return right_context_; }
@ -41,7 +41,7 @@ class U2NnetBase : public NnetInterface {
// start: false, it is the start chunk of one sentence, else true
virtual int num_frames_for_chunk(bool start) const;
virtual std::shared_ptr<NnetInterface> Copy() const = 0;
virtual std::shared_ptr<NnetBase> Copy() const = 0;
virtual void ForwardEncoderChunk(
const std::vector<kaldi::BaseFloat>& chunk_feats,
@ -99,7 +99,7 @@ class U2Nnet : public U2NnetBase {
std::shared_ptr<paddle::jit::Layer> model() const { return model_; }
std::shared_ptr<NnetInterface> Copy() const override;
std::shared_ptr<NnetBase> Copy() const override;
void ForwardEncoderChunkImpl(
const std::vector<kaldi::BaseFloat>& chunk_feats,

@ -2,7 +2,7 @@ add_library(websocket STATIC
websocket_server.cc
websocket_client.cc
)
target_link_libraries(websocket PUBLIC frontend decoder nnet)
target_link_libraries(websocket PUBLIC frontend nnet decoder recognizer)
add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc)
target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)

@ -19,7 +19,7 @@
#include "boost/asio/ip/tcp.hpp"
#include "boost/beast/core.hpp"
#include "boost/beast/websocket.hpp"
#include "decoder/recognizer.h"
#include "recognizer/recognizer.h"
#include "frontend/audio/feature_pipeline.h"
namespace beast = boost::beast; // from <boost/beast.hpp>

@ -0,0 +1,45 @@
# Build rules for the `recognizer` static library and its command-line drivers.
# Which sources and executables are built is controlled by the USING_DS2 and
# USING_U2 switches.
set(srcs)

if (USING_DS2)
    list(APPEND srcs
        recognizer.cc
    )
endif()

if (USING_U2)
    list(APPEND srcs
        u2_recognizer.cc
    )
endif()

# NOTE(review): `srcs` stays empty when neither USING_DS2 nor USING_U2 is
# enabled; confirm the top-level build always sets at least one of them.
add_library(recognizer STATIC ${srcs})
target_link_libraries(recognizer PUBLIC decoder)

# Driver executables: one `<name>.cc` in this directory per entry.
if (USING_DS2)
    set(BINS recognizer_main)

    foreach(bin_name IN LISTS BINS)
        add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
        target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
        # Executables have no downstream consumers, so PRIVATE is sufficient.
        target_link_libraries(${bin_name} PRIVATE recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
    endforeach()
endif()

if (USING_U2)
    set(TEST_BINS
        u2_recognizer_main
    )

    foreach(bin_name IN LISTS TEST_BINS)
        add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
        target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
        # Use the keyword signature consistently: CMake rejects mixing the
        # plain and keyword forms of target_link_libraries on one target.
        target_link_libraries(${bin_name} PRIVATE recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
        target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
        target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
        target_link_libraries(${bin_name} PRIVATE ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
    endforeach()
endif()

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
#include "recognizer/recognizer.h"
namespace ppspeech {

@ -13,7 +13,7 @@
// limitations under the License.
#include "decoder/param.h"
#include "decoder/recognizer.h"
#include "recognizer/recognizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
@ -22,15 +22,6 @@ DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(sample_rate, 16000, "sample rate");
// Builds a RecognizerResource from the command-line flags: the acoustic scale
// is read from FLAGS_acoustic_scale, and the feature pipeline, model and TLG
// decoder options are each obtained from their own InitFromFlags().
ppspeech::RecognizerResource InitRecognizerResoure() {
    ppspeech::RecognizerResource res;

    res.acoustic_scale = FLAGS_acoustic_scale;
    res.feature_pipeline_opts = ppspeech::FeaturePipelineOptions::InitFromFlags();
    res.model_opts = ppspeech::ModelOptions::InitFromFlags();
    res.tlg_opts = ppspeech::TLGDecoderOptions::InitFromFlags();

    return res;
}
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
@ -39,7 +30,7 @@ int main(int argc, char* argv[]) {
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
ppspeech::RecognizerResource resource = InitRecognizerResoure();
ppspeech::RecognizerResource resource = ppspeech::RecognizerResource::InitFromFlags();
ppspeech::Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/u2_recognizer.h"
#include "recognizer/u2_recognizer.h"
#include "nnet/u2_nnet.h"
@ -30,7 +30,7 @@ U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<NnetInterface> nnet(new U2Nnet(resource.model_opts));
std::shared_ptr<NnetBase> nnet(new U2Nnet(resource.model_opts));
BaseFloat am_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));

@ -13,7 +13,7 @@
// limitations under the License.
#include "decoder/param.h"
#include "decoder/u2_recognizer.h"
#include "recognizer/u2_recognizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
Loading…
Cancel
Save