[engine]fix asr compile (#3078)

* fix asr compile
* add pybind
pull/3112/head
YangZhou 2 years ago committed by GitHub
parent ab4217c2e4
commit 2be7e5725f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -72,6 +72,7 @@ include(gflags)
include(glog)
include(pybind)
# gtest
if(WITH_TESTING)
include(gtest) # download, build, install gtest
@ -109,6 +110,7 @@ if(WITH_ASR)
message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}")
include_directories(${PYTHON_INCLUDE_DIR})
if(pybind11_FOUND)
message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}")
@ -188,4 +190,4 @@ set(CPACK_PACKAGE_VERSION_PATCH 0)
set(CPACK_PACKAGE_DESCRIPTION "paddlespeech library")
set(CPACK_PACKAGE_CONTACT "paddlespeech@baidu.com")
set(CPACK_SOURCE_GENERATOR "TGZ")
include (CPack)
include (CPack)

@ -0,0 +1,42 @@
# pybind11 is from: https://github.com/pybind/pybind11
# Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>, All rights reserved.
#
# Fetches the pybind11 v2.10.0 archive into FETCHCONTENT_BASE_DIR (skipped when
# the archive is already present), unpacks it to
# ${FETCHCONTENT_BASE_DIR}/pybind11 (skipped when that directory already
# exists) and adds its include/ directory to the directory-scope include path.
#
# Variables set for the including scope:
#   PYBIND_ZIP        - archive file name
#   LOCAL_PYBIND_ZIP  - full path of the downloaded archive
#   PYBIND_SRC        - directory holding the unpacked pybind11 sources
#   DOWNLOAD_URL      - URL the archive is fetched from
#   PYBIND_TIMEOUT    - (cache) download timeout in seconds
set(PYBIND_ZIP "v2.10.0.zip")
set(LOCAL_PYBIND_ZIP "${FETCHCONTENT_BASE_DIR}/${PYBIND_ZIP}")
set(PYBIND_SRC "${FETCHCONTENT_BASE_DIR}/pybind11")
set(DOWNLOAD_URL "https://paddleaudio.bj.bcebos.com/build/v2.10.0.zip")
set(PYBIND_TIMEOUT 600 CACHE STRING "Timeout in seconds when downloading pybind.")

if(NOT EXISTS "${LOCAL_PYBIND_ZIP}")
  file(DOWNLOAD "${DOWNLOAD_URL}"
    "${LOCAL_PYBIND_ZIP}"
    TIMEOUT ${PYBIND_TIMEOUT}
    STATUS ERR
    SHOW_PROGRESS
  )
  # STATUS is a two-element list: a numeric code followed by a message.
  # Compare only the numeric code instead of the whole list.
  list(GET ERR 0 pybind_download_code)
  list(GET ERR 1 pybind_download_message)
  if(pybind_download_code EQUAL 0)
    message(STATUS "download pybind success")
  else()
    # Drop the partial archive; otherwise the EXISTS guard above would skip
    # the download on the next configure and extraction would keep failing.
    file(REMOVE "${LOCAL_PYBIND_ZIP}")
    message(FATAL_ERROR "download pybind fail: ${pybind_download_message}")
  endif()
endif()

if(NOT EXISTS "${PYBIND_SRC}")
  execute_process(
    COMMAND "${CMAKE_COMMAND}" -E tar xfz "${LOCAL_PYBIND_ZIP}"
    WORKING_DIRECTORY "${FETCHCONTENT_BASE_DIR}"
    RESULT_VARIABLE tar_result
  )
  # Check the extraction result before touching the extracted tree, and use a
  # numeric comparison: a regex match on "0" would also accept e.g. "10".
  if(tar_result EQUAL 0)
    # The archive unpacks to pybind11-2.10.0; move it to the stable name.
    file(RENAME "${FETCHCONTENT_BASE_DIR}/pybind11-2.10.0" "${PYBIND_SRC}")
    message(STATUS "unzip pybind success")
  else()
    message(FATAL_ERROR
      "unzip pybind fail (result: ${tar_result}); "
      "if ${LOCAL_PYBIND_ZIP} is corrupt, delete it and re-run cmake")
  endif()
endif()

include_directories("${PYBIND_SRC}/include")

@ -101,10 +101,10 @@ void U2Nnet::Warmup() {
auto encoder_out = paddle::ones(
{1, 20, 512}, paddle::DataType::FLOAT32, phi::CPUPlace());
std::vector<paddle::experimental::Tensor> inputs{
std::vector<paddle::Tensor> inputs{
hyps, hyps_lens, encoder_out};
std::vector<paddle::experimental::Tensor> outputs =
std::vector<paddle::Tensor> outputs =
forward_attention_decoder_(inputs);
}
@ -523,7 +523,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
}
#endif // end TEST_DEBUG
std::vector<paddle::experimental::Tensor> inputs{
std::vector<paddle::Tensor> inputs{
hyps_tensor, hyps_lens, encoder_out};
std::vector<paddle::Tensor> outputs = forward_attention_decoder_(inputs);
CHECK_EQ(outputs.size(), 2);
@ -594,7 +594,7 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
} else {
// dump r_probs
CHECK_EQ(r_probs_shape.size(), 1);
CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0];
//CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0];
}
// compute rescoring score
@ -604,15 +604,15 @@ void U2Nnet::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
VLOG(2) << "split prob: " << probs_v.size() << " "
<< probs_v[0].shape().size() << " 0: " << probs_v[0].shape()[0]
<< ", " << probs_v[0].shape()[1] << ", " << probs_v[0].shape()[2];
CHECK(static_cast<int>(probs_v.size()) == num_hyps)
<< ": is " << probs_v.size() << " expect: " << num_hyps;
//CHECK(static_cast<int>(probs_v.size()) == num_hyps)
// << ": is " << probs_v.size() << " expect: " << num_hyps;
std::vector<paddle::Tensor> r_probs_v;
if (is_bidecoder_ && reverse_weight > 0) {
r_probs_v = paddle::experimental::split_with_num(r_probs, num_hyps, 0);
CHECK(static_cast<int>(r_probs_v.size()) == num_hyps)
<< "r_probs_v size: is " << r_probs_v.size()
<< " expect: " << num_hyps;
//CHECK(static_cast<int>(r_probs_v.size()) == num_hyps)
// << "r_probs_v size: is " << r_probs_v.size()
// << " expect: " << num_hyps;
}
for (int i = 0; i < num_hyps; ++i) {
@ -654,7 +654,7 @@ void U2Nnet::EncoderOuts(
const int& B = shape[0];
const int& T = shape[1];
const int& D = shape[2];
CHECK(B == 1) << "Only support batch one.";
//CHECK(B == 1) << "Only support batch one.";
VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << ","
<< D << ")";
@ -668,4 +668,4 @@ void U2Nnet::EncoderOuts(
}
}
} // namespace ppspeech
} // namespace ppspeech

Loading…
Cancel
Save