demos/TTSArmLinux: adjust the code style to match the project

pull/2991/head
彭逸豪 3 years ago
parent 0f62ccc0d2
commit 447a68eaab

@ -10,23 +10,17 @@
using namespace paddle::lite_api; using namespace paddle::lite_api;
class Predictor { class Predictor {
private:
float inferenceTime = 0;
std::shared_ptr<PaddlePredictor> AMPredictor = nullptr;
std::shared_ptr<PaddlePredictor> VOCPredictor = nullptr;
std::vector<float> wav;
public: public:
bool init(const std::string &AMModelPath, const std::string &VOCModelPath, int cpuThreadNum, const std::string &cpuPowerMode) { bool Init(const std::string &AMModelPath, const std::string &VOCModelPath, int cpuThreadNum, const std::string &cpuPowerMode) {
// Release model if exists // Release model if exists
releaseModel(); ReleaseModel();
AMPredictor = loadModel(AMModelPath, cpuThreadNum, cpuPowerMode); AM_predictor_ = LoadModel(AMModelPath, cpuThreadNum, cpuPowerMode);
if (AMPredictor == nullptr) { if (AM_predictor_ == nullptr) {
return false; return false;
} }
VOCPredictor = loadModel(VOCModelPath, cpuThreadNum, cpuPowerMode); VOC_predictor_ = LoadModel(VOCModelPath, cpuThreadNum, cpuPowerMode);
if (VOCPredictor == nullptr) { if (VOC_predictor_ == nullptr) {
return false; return false;
} }
@ -34,11 +28,11 @@ public:
} }
~Predictor() { ~Predictor() {
releaseModel(); ReleaseModel();
releaseWav(); ReleaseWav();
} }
std::shared_ptr<PaddlePredictor> loadModel(const std::string &modelPath, int cpuThreadNum, const std::string &cpuPowerMode) { std::shared_ptr<PaddlePredictor> LoadModel(const std::string &modelPath, int cpuThreadNum, const std::string &cpuPowerMode) {
if (modelPath.empty()) { if (modelPath.empty()) {
return nullptr; return nullptr;
} }
@ -68,13 +62,13 @@ public:
return CreatePaddlePredictor<MobileConfig>(config); return CreatePaddlePredictor<MobileConfig>(config);
} }
void releaseModel() { void ReleaseModel() {
AMPredictor = nullptr; AM_predictor_ = nullptr;
VOCPredictor = nullptr; VOC_predictor_ = nullptr;
} }
bool runModel(const std::vector<float> &phones) { bool RunModel(const std::vector<float> &phones) {
if (!isLoaded()) { if (!IsLoaded()) {
return false; return false;
} }
@ -82,26 +76,26 @@ public:
auto start = std::chrono::system_clock::now(); auto start = std::chrono::system_clock::now();
// 执行推理 // 执行推理
VOCOutputToWav(getAMOutput(phones)); VOCOutputToWav(GetAMOutput(phones));
// 计时结束 // 计时结束
auto end = std::chrono::system_clock::now(); auto end = std::chrono::system_clock::now();
// 计算用时 // 计算用时
std::chrono::duration<float> duration = end - start; std::chrono::duration<float> duration = end - start;
inferenceTime = duration.count() * 1000; // 单位:毫秒 inference_time_ = duration.count() * 1000; // 单位:毫秒
return true; return true;
} }
std::unique_ptr<const Tensor> getAMOutput(const std::vector<float> &phones) { std::unique_ptr<const Tensor> GetAMOutput(const std::vector<float> &phones) {
auto phones_handle = AMPredictor->GetInput(0); auto phones_handle = AM_predictor_->GetInput(0);
phones_handle->Resize({static_cast<int64_t>(phones.size())}); phones_handle->Resize({static_cast<int64_t>(phones.size())});
phones_handle->CopyFromCpu(phones.data()); phones_handle->CopyFromCpu(phones.data());
AMPredictor->Run(); AM_predictor_->Run();
// 获取输出Tensor // 获取输出Tensor
auto am_output_handle = AMPredictor->GetOutput(0); auto am_output_handle = AM_predictor_->GetOutput(0);
// 打印输出Tensor的shape // 打印输出Tensor的shape
std::cout << "AM Output shape: "; std::cout << "AM Output shape: ";
auto shape = am_output_handle->shape(); auto shape = am_output_handle->shape();
@ -116,16 +110,16 @@ public:
} }
void VOCOutputToWav(std::unique_ptr<const Tensor> &&input) { void VOCOutputToWav(std::unique_ptr<const Tensor> &&input) {
auto mel_handle = VOCPredictor->GetInput(0); auto mel_handle = VOC_predictor_->GetInput(0);
// [?, 80] // [?, 80]
auto dims = input->shape(); auto dims = input->shape();
mel_handle->Resize(dims); mel_handle->Resize(dims);
auto am_output_data = input->mutable_data<float>(); auto am_output_data = input->mutable_data<float>();
mel_handle->CopyFromCpu(am_output_data); mel_handle->CopyFromCpu(am_output_data);
VOCPredictor->Run(); VOC_predictor_->Run();
// 获取输出Tensor // 获取输出Tensor
auto voc_output_handle = VOCPredictor->GetOutput(0); auto voc_output_handle = VOC_predictor_->GetOutput(0);
// 打印输出Tensor的shape // 打印输出Tensor的shape
std::cout << "VOC Output shape: "; std::cout << "VOC Output shape: ";
auto shape = voc_output_handle->shape(); auto shape = voc_output_handle->shape();
@ -139,25 +133,25 @@ public:
for (auto dim : voc_output_handle->shape()) { for (auto dim : voc_output_handle->shape()) {
output_size *= dim; output_size *= dim;
} }
wav.resize(output_size); wav_.resize(output_size);
auto output_data = voc_output_handle->mutable_data<float>(); auto output_data = voc_output_handle->mutable_data<float>();
std::copy_n(output_data, output_size, wav.data()); std::copy_n(output_data, output_size, wav_.data());
} }
bool isLoaded() { bool IsLoaded() {
return AMPredictor != nullptr && VOCPredictor != nullptr; return AM_predictor_ != nullptr && VOC_predictor_ != nullptr;
} }
float getInferenceTime() { float GetInferenceTime() {
return inferenceTime; return inference_time_;
} }
const std::vector<float> & getWav() { const std::vector<float> & GetWav() {
return wav; return wav_;
} }
void releaseWav() { void ReleaseWav() {
wav.clear(); wav_.clear();
} }
struct WavHeader { struct WavHeader {
@ -185,7 +179,7 @@ public:
uint32_t data_size = 0; uint32_t data_size = 0;
}; };
bool writeWavToFile(const std::string &wavPath) { bool WriteWavToFile(const std::string &wavPath) {
std::ofstream fout(wavPath, std::ios::binary); std::ofstream fout(wavPath, std::ios::binary);
if (!fout.is_open()) { if (!fout.is_open()) {
return false; return false;
@ -194,15 +188,21 @@ public:
// 写入头信息 // 写入头信息
WavHeader header; WavHeader header;
header.size = sizeof(header) - 8; header.size = sizeof(header) - 8;
header.data_size = wav.size() * sizeof(float); header.data_size = wav_.size() * sizeof(float);
header.byte_rate = header.sample_rate * header.num_channels * header.bits_per_sample / 8; header.byte_rate = header.sample_rate * header.num_channels * header.bits_per_sample / 8;
header.block_align = header.num_channels * header.bits_per_sample / 8; header.block_align = header.num_channels * header.bits_per_sample / 8;
fout.write(reinterpret_cast<const char*>(&header), sizeof(header)); fout.write(reinterpret_cast<const char*>(&header), sizeof(header));
// 写入wav数据 // 写入wav数据
fout.write(reinterpret_cast<const char*>(wav.data()), header.data_size); fout.write(reinterpret_cast<const char*>(wav_.data()), header.data_size);
fout.close(); fout.close();
return true; return true;
} }
private:
float inference_time_ = 0;
std::shared_ptr<PaddlePredictor> AM_predictor_ = nullptr;
std::shared_ptr<PaddlePredictor> VOC_predictor_ = nullptr;
std::vector<float> wav_;
}; };

@ -50,19 +50,19 @@ int main(int argc, char *argv[]) {
} }
Predictor predictor; Predictor predictor;
if (!predictor.init(AMModelPath, VOCModelPath, 1, "LITE_POWER_HIGH")) { if (!predictor.Init(AMModelPath, VOCModelPath, 1, "LITE_POWER_HIGH")) {
std::cerr << "predictor init failed" << std::endl; std::cerr << "predictor init failed" << std::endl;
return -1; return -1;
} }
if (!predictor.runModel(sentencesToChoose[sentencesIndex])) { if (!predictor.RunModel(sentencesToChoose[sentencesIndex])) {
std::cerr << "predictor run model failed" << std::endl; std::cerr << "predictor run model failed" << std::endl;
return -1; return -1;
} }
std::cout << "Inference time: " << predictor.getInferenceTime() << "ms, WAV size: " << predictor.getWav().size() << std::endl; std::cout << "Inference time: " << predictor.GetInferenceTime() << "ms, WAV size: " << predictor.GetWav().size() << std::endl;
if (!predictor.writeWavToFile(outputWavPath)) { if (!predictor.WriteWavToFile(outputWavPath)) {
std::cerr << "write wav file failed" << std::endl; std::cerr << "write wav file failed" << std::endl;
return -1; return -1;
} }

Loading…
Cancel
Save