demos/TTSArmLinux: adjust the code style to match the project

pull/2991/head
彭逸豪 3 years ago
parent 0f62ccc0d2
commit 447a68eaab
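
In short, the renames below bring the demo in line with the naming style used elsewhere in the project: PascalCase method names and private data members moved to the bottom of the class with a trailing underscore. A minimal before/after sketch of the pattern (`Example` is a hypothetical class used only for illustration):

// Before: camelCase methods, members without a suffix
class Example {
    float inferenceTime = 0;
public:
    bool runModel();
};

// After: PascalCase methods, members at the bottom with a trailing underscore
class Example {
public:
    bool RunModel();
private:
    float inference_time_ = 0;
};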

@@ -10,23 +10,17 @@
 using namespace paddle::lite_api;
 class Predictor {
-private:
-    float inferenceTime = 0;
-    std::shared_ptr<PaddlePredictor> AMPredictor = nullptr;
-    std::shared_ptr<PaddlePredictor> VOCPredictor = nullptr;
-    std::vector<float> wav;
 public:
-    bool init(const std::string &AMModelPath, const std::string &VOCModelPath, int cpuThreadNum, const std::string &cpuPowerMode) {
+    bool Init(const std::string &AMModelPath, const std::string &VOCModelPath, int cpuThreadNum, const std::string &cpuPowerMode) {
         // Release model if exists
-        releaseModel();
+        ReleaseModel();
-        AMPredictor = loadModel(AMModelPath, cpuThreadNum, cpuPowerMode);
-        if (AMPredictor == nullptr) {
+        AM_predictor_ = LoadModel(AMModelPath, cpuThreadNum, cpuPowerMode);
+        if (AM_predictor_ == nullptr) {
             return false;
         }
-        VOCPredictor = loadModel(VOCModelPath, cpuThreadNum, cpuPowerMode);
-        if (VOCPredictor == nullptr) {
+        VOC_predictor_ = LoadModel(VOCModelPath, cpuThreadNum, cpuPowerMode);
+        if (VOC_predictor_ == nullptr) {
             return false;
         }
@@ -34,11 +28,11 @@ public:
     }
     ~Predictor() {
-        releaseModel();
-        releaseWav();
+        ReleaseModel();
+        ReleaseWav();
     }
-    std::shared_ptr<PaddlePredictor> loadModel(const std::string &modelPath, int cpuThreadNum, const std::string &cpuPowerMode) {
+    std::shared_ptr<PaddlePredictor> LoadModel(const std::string &modelPath, int cpuThreadNum, const std::string &cpuPowerMode) {
         if (modelPath.empty()) {
             return nullptr;
         }
@@ -68,13 +62,13 @@ public:
         return CreatePaddlePredictor<MobileConfig>(config);
     }
-    void releaseModel() {
-        AMPredictor = nullptr;
-        VOCPredictor = nullptr;
+    void ReleaseModel() {
+        AM_predictor_ = nullptr;
+        VOC_predictor_ = nullptr;
     }
-    bool runModel(const std::vector<float> &phones) {
-        if (!isLoaded()) {
+    bool RunModel(const std::vector<float> &phones) {
+        if (!IsLoaded()) {
             return false;
         }
@@ -82,26 +76,26 @@ public:
         auto start = std::chrono::system_clock::now();
         // Run inference
-        VOCOutputToWav(getAMOutput(phones));
+        VOCOutputToWav(GetAMOutput(phones));
         // Stop the timer
         auto end = std::chrono::system_clock::now();
         // Compute the elapsed time
         std::chrono::duration<float> duration = end - start;
-        inferenceTime = duration.count() * 1000; // unit: milliseconds
+        inference_time_ = duration.count() * 1000; // unit: milliseconds
         return true;
     }
-    std::unique_ptr<const Tensor> getAMOutput(const std::vector<float> &phones) {
-        auto phones_handle = AMPredictor->GetInput(0);
+    std::unique_ptr<const Tensor> GetAMOutput(const std::vector<float> &phones) {
+        auto phones_handle = AM_predictor_->GetInput(0);
         phones_handle->Resize({static_cast<int64_t>(phones.size())});
         phones_handle->CopyFromCpu(phones.data());
-        AMPredictor->Run();
+        AM_predictor_->Run();
         // Get the output tensor
-        auto am_output_handle = AMPredictor->GetOutput(0);
+        auto am_output_handle = AM_predictor_->GetOutput(0);
         // Print the shape of the output tensor
         std::cout << "AM Output shape: ";
         auto shape = am_output_handle->shape();
@@ -116,16 +110,16 @@ public:
     }
     void VOCOutputToWav(std::unique_ptr<const Tensor> &&input) {
-        auto mel_handle = VOCPredictor->GetInput(0);
+        auto mel_handle = VOC_predictor_->GetInput(0);
         // [?, 80]
         auto dims = input->shape();
         mel_handle->Resize(dims);
         auto am_output_data = input->mutable_data<float>();
         mel_handle->CopyFromCpu(am_output_data);
-        VOCPredictor->Run();
+        VOC_predictor_->Run();
         // Get the output tensor
-        auto voc_output_handle = VOCPredictor->GetOutput(0);
+        auto voc_output_handle = VOC_predictor_->GetOutput(0);
         // Print the shape of the output tensor
         std::cout << "VOC Output shape: ";
         auto shape = voc_output_handle->shape();
@@ -139,25 +133,25 @@ public:
         for (auto dim : voc_output_handle->shape()) {
             output_size *= dim;
         }
-        wav.resize(output_size);
+        wav_.resize(output_size);
         auto output_data = voc_output_handle->mutable_data<float>();
-        std::copy_n(output_data, output_size, wav.data());
+        std::copy_n(output_data, output_size, wav_.data());
     }
-    bool isLoaded() {
-        return AMPredictor != nullptr && VOCPredictor != nullptr;
+    bool IsLoaded() {
+        return AM_predictor_ != nullptr && VOC_predictor_ != nullptr;
     }
-    float getInferenceTime() {
-        return inferenceTime;
+    float GetInferenceTime() {
+        return inference_time_;
     }
-    const std::vector<float> & getWav() {
-        return wav;
+    const std::vector<float> & GetWav() {
+        return wav_;
     }
-    void releaseWav() {
-        wav.clear();
+    void ReleaseWav() {
+        wav_.clear();
     }
     struct WavHeader {
@@ -185,7 +179,7 @@ public:
         uint32_t data_size = 0;
     };
-    bool writeWavToFile(const std::string &wavPath) {
+    bool WriteWavToFile(const std::string &wavPath) {
         std::ofstream fout(wavPath, std::ios::binary);
         if (!fout.is_open()) {
             return false;
@@ -194,15 +188,21 @@ public:
         // Write the header
         WavHeader header;
         header.size = sizeof(header) - 8;
-        header.data_size = wav.size() * sizeof(float);
+        header.data_size = wav_.size() * sizeof(float);
         header.byte_rate = header.sample_rate * header.num_channels * header.bits_per_sample / 8;
         header.block_align = header.num_channels * header.bits_per_sample / 8;
         fout.write(reinterpret_cast<const char*>(&header), sizeof(header));
         // Write the wav data
-        fout.write(reinterpret_cast<const char*>(wav.data()), header.data_size);
+        fout.write(reinterpret_cast<const char*>(wav_.data()), header.data_size);
         fout.close();
         return true;
     }
+private:
+    float inference_time_ = 0;
+    std::shared_ptr<PaddlePredictor> AM_predictor_ = nullptr;
+    std::shared_ptr<PaddlePredictor> VOC_predictor_ = nullptr;
+    std::vector<float> wav_;
 };
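
The body of LoadModel falls mostly outside the hunks above, so the diff only shows the final CreatePaddlePredictor<MobileConfig>(config) call. For readers unfamiliar with Paddle Lite, here is a minimal sketch of what such a loader typically does with the cpuThreadNum and cpuPowerMode parameters; it is an assumption about the shape of the elided code, not the demo's exact implementation (the string-to-PowerMode mapping in particular is illustrative):

#include <memory>
#include <string>
#include "paddle_api.h"  // Paddle Lite C++ API

using namespace paddle::lite_api;

// Hypothetical loader sketch: configure a MobileConfig and create a predictor.
std::shared_ptr<PaddlePredictor> LoadModelSketch(const std::string &modelPath,
                                                 int cpuThreadNum,
                                                 const std::string &cpuPowerMode) {
    MobileConfig config;
    config.set_model_from_file(modelPath);  // optimized .nb model
    config.set_threads(cpuThreadNum);
    // Map the textual power mode onto the Paddle Lite enum (only two values shown).
    if (cpuPowerMode == "LITE_POWER_HIGH") {
        config.set_power_mode(LITE_POWER_HIGH);
    } else {
        config.set_power_mode(LITE_POWER_NO_BIND);
    }
    return CreatePaddlePredictor<MobileConfig>(config);
}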

@@ -50,19 +50,19 @@ int main(int argc, char *argv[]) {
     }
     Predictor predictor;
-    if (!predictor.init(AMModelPath, VOCModelPath, 1, "LITE_POWER_HIGH")) {
+    if (!predictor.Init(AMModelPath, VOCModelPath, 1, "LITE_POWER_HIGH")) {
         std::cerr << "predictor init failed" << std::endl;
         return -1;
     }
-    if (!predictor.runModel(sentencesToChoose[sentencesIndex])) {
+    if (!predictor.RunModel(sentencesToChoose[sentencesIndex])) {
         std::cerr << "predictor run model failed" << std::endl;
         return -1;
     }
-    std::cout << "Inference time: " << predictor.getInferenceTime() << "ms, WAV size: " << predictor.getWav().size() << std::endl;
+    std::cout << "Inference time: " << predictor.GetInferenceTime() << "ms, WAV size: " << predictor.GetWav().size() << std::endl;
-    if (!predictor.writeWavToFile(outputWavPath)) {
+    if (!predictor.WriteWavToFile(outputWavPath)) {
         std::cerr << "write wav file failed" << std::endl;
         return -1;
     }
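
WriteWavToFile, called at the end of main above, derives byte_rate and block_align from the other header fields before dumping the raw float samples. Most WavHeader fields fall outside the diff hunks; purely as a reference for that arithmetic, a typical 44-byte header for mono 32-bit IEEE-float audio looks roughly like the sketch below (field values such as the 24 kHz sample rate are assumptions, not taken from the demo):

#include <cstdint>

// Illustrative layout of a standard 44-byte WAV header for float PCM.
// Only size, data_size, byte_rate, block_align, sample_rate, num_channels
// and bits_per_sample appear in the diff; everything else here is assumed.
struct WavHeaderSketch {
    // RIFF chunk
    char riff[4] = {'R', 'I', 'F', 'F'};
    uint32_t size = 0;                 // RIFF chunk size (file size minus 8 bytes)
    char wave[4] = {'W', 'A', 'V', 'E'};
    // fmt sub-chunk
    char fmt[4] = {'f', 'm', 't', ' '};
    uint32_t fmt_size = 16;
    uint16_t audio_format = 3;         // 3 = IEEE float samples
    uint16_t num_channels = 1;
    uint32_t sample_rate = 24000;      // assumed; depends on the vocoder
    uint32_t byte_rate = 0;            // sample_rate * num_channels * bits_per_sample / 8
    uint16_t block_align = 0;          // num_channels * bits_per_sample / 8
    uint16_t bits_per_sample = 32;     // 32-bit float
    // data sub-chunk
    char data[4] = {'d', 'a', 't', 'a'};
    uint32_t data_size = 0;            // wav_.size() * sizeof(float)
};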
