diff --git a/demos/TTSArmLinux/src/Predictor.hpp b/demos/TTSArmLinux/src/Predictor.hpp index 8c4f4655d..765d859be 100644 --- a/demos/TTSArmLinux/src/Predictor.hpp +++ b/demos/TTSArmLinux/src/Predictor.hpp @@ -10,23 +10,17 @@ using namespace paddle::lite_api; class Predictor { -private: - float inferenceTime = 0; - std::shared_ptr AMPredictor = nullptr; - std::shared_ptr VOCPredictor = nullptr; - std::vector wav; - public: - bool init(const std::string &AMModelPath, const std::string &VOCModelPath, int cpuThreadNum, const std::string &cpuPowerMode) { + bool Init(const std::string &AMModelPath, const std::string &VOCModelPath, int cpuThreadNum, const std::string &cpuPowerMode) { // Release model if exists - releaseModel(); + ReleaseModel(); - AMPredictor = loadModel(AMModelPath, cpuThreadNum, cpuPowerMode); - if (AMPredictor == nullptr) { + AM_predictor_ = LoadModel(AMModelPath, cpuThreadNum, cpuPowerMode); + if (AM_predictor_ == nullptr) { return false; } - VOCPredictor = loadModel(VOCModelPath, cpuThreadNum, cpuPowerMode); - if (VOCPredictor == nullptr) { + VOC_predictor_ = LoadModel(VOCModelPath, cpuThreadNum, cpuPowerMode); + if (VOC_predictor_ == nullptr) { return false; } @@ -34,11 +28,11 @@ public: } ~Predictor() { - releaseModel(); - releaseWav(); + ReleaseModel(); + ReleaseWav(); } - std::shared_ptr loadModel(const std::string &modelPath, int cpuThreadNum, const std::string &cpuPowerMode) { + std::shared_ptr LoadModel(const std::string &modelPath, int cpuThreadNum, const std::string &cpuPowerMode) { if (modelPath.empty()) { return nullptr; } @@ -68,13 +62,13 @@ public: return CreatePaddlePredictor(config); } - void releaseModel() { - AMPredictor = nullptr; - VOCPredictor = nullptr; + void ReleaseModel() { + AM_predictor_ = nullptr; + VOC_predictor_ = nullptr; } - bool runModel(const std::vector &phones) { - if (!isLoaded()) { + bool RunModel(const std::vector &phones) { + if (!IsLoaded()) { return false; } @@ -82,26 +76,26 @@ public: auto start = std::chrono::system_clock::now(); // 执行推理 - VOCOutputToWav(getAMOutput(phones)); + VOCOutputToWav(GetAMOutput(phones)); // 计时结束 auto end = std::chrono::system_clock::now(); // 计算用时 std::chrono::duration duration = end - start; - inferenceTime = duration.count() * 1000; // 单位:毫秒 + inference_time_ = duration.count() * 1000; // 单位:毫秒 return true; } - std::unique_ptr getAMOutput(const std::vector &phones) { - auto phones_handle = AMPredictor->GetInput(0); + std::unique_ptr GetAMOutput(const std::vector &phones) { + auto phones_handle = AM_predictor_->GetInput(0); phones_handle->Resize({static_cast(phones.size())}); phones_handle->CopyFromCpu(phones.data()); - AMPredictor->Run(); + AM_predictor_->Run(); // 获取输出Tensor - auto am_output_handle = AMPredictor->GetOutput(0); + auto am_output_handle = AM_predictor_->GetOutput(0); // 打印输出Tensor的shape std::cout << "AM Output shape: "; auto shape = am_output_handle->shape(); @@ -116,16 +110,16 @@ public: } void VOCOutputToWav(std::unique_ptr &&input) { - auto mel_handle = VOCPredictor->GetInput(0); + auto mel_handle = VOC_predictor_->GetInput(0); // [?, 80] auto dims = input->shape(); mel_handle->Resize(dims); auto am_output_data = input->mutable_data(); mel_handle->CopyFromCpu(am_output_data); - VOCPredictor->Run(); + VOC_predictor_->Run(); // 获取输出Tensor - auto voc_output_handle = VOCPredictor->GetOutput(0); + auto voc_output_handle = VOC_predictor_->GetOutput(0); // 打印输出Tensor的shape std::cout << "VOC Output shape: "; auto shape = voc_output_handle->shape(); @@ -139,25 +133,25 @@ public: for (auto dim : voc_output_handle->shape()) { output_size *= dim; } - wav.resize(output_size); + wav_.resize(output_size); auto output_data = voc_output_handle->mutable_data(); - std::copy_n(output_data, output_size, wav.data()); + std::copy_n(output_data, output_size, wav_.data()); } - bool isLoaded() { - return AMPredictor != nullptr && VOCPredictor != nullptr; + bool IsLoaded() { + return AM_predictor_ != nullptr && VOC_predictor_ != nullptr; } - float getInferenceTime() { - return inferenceTime; + float GetInferenceTime() { + return inference_time_; } - const std::vector & getWav() { - return wav; + const std::vector & GetWav() { + return wav_; } - void releaseWav() { - wav.clear(); + void ReleaseWav() { + wav_.clear(); } struct WavHeader { @@ -185,7 +179,7 @@ public: uint32_t data_size = 0; }; - bool writeWavToFile(const std::string &wavPath) { + bool WriteWavToFile(const std::string &wavPath) { std::ofstream fout(wavPath, std::ios::binary); if (!fout.is_open()) { return false; @@ -194,15 +188,21 @@ public: // 写入头信息 WavHeader header; header.size = sizeof(header) - 8; - header.data_size = wav.size() * sizeof(float); + header.data_size = wav_.size() * sizeof(float); header.byte_rate = header.sample_rate * header.num_channels * header.bits_per_sample / 8; header.block_align = header.num_channels * header.bits_per_sample / 8; fout.write(reinterpret_cast(&header), sizeof(header)); // 写入wav数据 - fout.write(reinterpret_cast(wav.data()), header.data_size); + fout.write(reinterpret_cast(wav_.data()), header.data_size); fout.close(); return true; } + +private: + float inference_time_ = 0; + std::shared_ptr AM_predictor_ = nullptr; + std::shared_ptr VOC_predictor_ = nullptr; + std::vector wav_; }; diff --git a/demos/TTSArmLinux/src/main.cc b/demos/TTSArmLinux/src/main.cc index 64aeaa857..4068cb0b9 100644 --- a/demos/TTSArmLinux/src/main.cc +++ b/demos/TTSArmLinux/src/main.cc @@ -50,19 +50,19 @@ int main(int argc, char *argv[]) { } Predictor predictor; - if (!predictor.init(AMModelPath, VOCModelPath, 1, "LITE_POWER_HIGH")) { + if (!predictor.Init(AMModelPath, VOCModelPath, 1, "LITE_POWER_HIGH")) { std::cerr << "predictor init failed" << std::endl; return -1; } - if (!predictor.runModel(sentencesToChoose[sentencesIndex])) { + if (!predictor.RunModel(sentencesToChoose[sentencesIndex])) { std::cerr << "predictor run model failed" << std::endl; return -1; } - std::cout << "Inference time: " << predictor.getInferenceTime() << "ms, WAV size: " << predictor.getWav().size() << std::endl; + std::cout << "Inference time: " << predictor.GetInferenceTime() << "ms, WAV size: " << predictor.GetWav().size() << std::endl; - if (!predictor.writeWavToFile(outputWavPath)) { + if (!predictor.WriteWavToFile(outputWavPath)) { std::cerr << "write wav file failed" << std::endl; return -1; }