From e5fb491f1dd1fbe65e82f9d058859ca3ec06c3f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E9=80=B8=E8=B1=AA?= Date: Wed, 8 Mar 2023 06:02:29 +0000 Subject: [PATCH] refac: [demos/TTSArmLinux] support 16-bit PCM and 32-bit IEEE float WAV output with class template --- demos/TTSArmLinux/src/Predictor.hpp | 175 ++++++++++++++++------------ demos/TTSArmLinux/src/main.cc | 20 +++- 2 files changed, 116 insertions(+), 79 deletions(-) diff --git a/demos/TTSArmLinux/src/Predictor.hpp b/demos/TTSArmLinux/src/Predictor.hpp index 87e883f5e..5c59d417f 100644 --- a/demos/TTSArmLinux/src/Predictor.hpp +++ b/demos/TTSArmLinux/src/Predictor.hpp @@ -9,18 +9,51 @@ using namespace paddle::lite_api; -// WAV采样率 -// 如果播放速度和音调异常,请修改采样率 -// 常见采样率:16000, 24000, 32000, 44100, 48000, 96000 -#define WAV_SAMPLE_RATE 24000 - -// WAV数据类型 -// 定义在此处以便在 int16_t 和 float 之间切换 -typedef int16_t WavDataType; - +// WavDataType: WAV数据类型 +// 可在 int16_t 和 float 之间切换, +// 用于生成 16-bit PCM 或 32-bit IEEE float 格式的 WAV +template class Predictor { public: - bool Init(const std::string &AMModelPath, const std::string &VOCModelPath, int cpuThreadNum, const std::string &cpuPowerMode) { + struct WavHeader { + // RIFF 头 + char riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t size = 0; + char wave[4] = {'W', 'A', 'V', 'E'}; + + // FMT 头 + char fmt[4] = {'f', 'm', 't', ' '}; + uint32_t fmt_size = 16; + uint16_t audio_format = 0; + uint16_t num_channels = 1; + uint32_t sample_rate = 0; + uint32_t byte_rate = 0; + uint16_t block_align = 0; + uint16_t bits_per_sample = sizeof(WavDataType) * 8; + + // DATA 头 + char data[4] = {'d', 'a', 't', 'a'}; + uint32_t data_size = 0; + }; + + enum WavAudioFormat { + WAV_FORMAT_16BIT_PCM = 1, // 16-bit PCM 格式 + WAV_FORMAT_32BIT_FLOAT = 3 // 32-bit IEEE float 格式 + }; + + // 返回值通过模板特化由 WavDataType 决定 + inline uint16_t GetWavAudioFormat(); + + bool Init( + const std::string &AMModelPath, + const std::string &VOCModelPath, + PowerMode cpuPowerMode, + int cpuThreadNum, + // WAV采样率(必须与模型输出匹配) + // 如果播放速度和音调异常,请修改采样率 + // 常见采样率:16000, 24000, 32000, 44100, 48000, 96000 + uint32_t wavSampleRate + ) { // Release model if exists ReleaseModel(); @@ -33,6 +66,8 @@ public: return false; } + wav_sample_rate_ = wavSampleRate; + return true; } @@ -41,7 +76,7 @@ public: ReleaseWav(); } - std::shared_ptr LoadModel(const std::string &modelPath, int cpuThreadNum, const std::string &cpuPowerMode) { + std::shared_ptr LoadModel(const std::string &modelPath, int cpuThreadNum, PowerMode cpuPowerMode) { if (modelPath.empty()) { return nullptr; } @@ -50,23 +85,7 @@ public: MobileConfig config; config.set_model_from_file(modelPath); config.set_threads(cpuThreadNum); - - if (cpuPowerMode == "LITE_POWER_HIGH") { - config.set_power_mode(PowerMode::LITE_POWER_HIGH); - } else if (cpuPowerMode == "LITE_POWER_LOW") { - config.set_power_mode(PowerMode::LITE_POWER_LOW); - } else if (cpuPowerMode == "LITE_POWER_FULL") { - config.set_power_mode(PowerMode::LITE_POWER_FULL); - } else if (cpuPowerMode == "LITE_POWER_NO_BIND") { - config.set_power_mode(PowerMode::LITE_POWER_NO_BIND); - } else if (cpuPowerMode == "LITE_POWER_RAND_HIGH") { - config.set_power_mode(PowerMode::LITE_POWER_RAND_HIGH); - } else if (cpuPowerMode == "LITE_POWER_RAND_LOW") { - config.set_power_mode(PowerMode::LITE_POWER_RAND_LOW); - } else { - std::cerr << "Unknown cpu power mode!" << std::endl; - return nullptr; - } + config.set_power_mode(cpuPowerMode); return CreatePaddlePredictor(config); } @@ -85,7 +104,7 @@ public: auto start = std::chrono::system_clock::now(); // 执行推理 - VOCOutputToWav(GetAMOutput(phones)); + VOCOutputToWav(GetVOCOutput(GetAMOutput(phones))); // 计时结束 auto end = std::chrono::system_clock::now(); @@ -116,12 +135,12 @@ public: return am_output_handle; } - void VOCOutputToWav(std::unique_ptr &&input) { + std::unique_ptr GetVOCOutput(std::unique_ptr &&amOutput) { auto mel_handle = VOC_predictor_->GetInput(0); // [?, 80] - auto dims = input->shape(); + auto dims = amOutput->shape(); mel_handle->Resize(dims); - auto am_output_data = input->mutable_data(); + auto am_output_data = amOutput->mutable_data(); mel_handle->CopyFromCpu(am_output_data); VOC_predictor_->Run(); @@ -135,12 +154,16 @@ public: } std::cout << std::endl; + return voc_output_handle; + } + + void VOCOutputToWav(std::unique_ptr &&vocOutput) { // 获取输出Tensor的数据 int64_t output_size = 1; - for (auto dim : voc_output_handle->shape()) { + for (auto dim : vocOutput->shape()) { output_size *= dim; } - auto output_data = voc_output_handle->mutable_data(); + auto output_data = vocOutput->mutable_data(); SaveFloatWav(output_data, output_size); } @@ -149,21 +172,7 @@ public: return (number < 0) ? -number : number; } - void SaveFloatWav(float *floatWav, int64_t size) { - wav_.resize(size); - float maxSample = 0.01; - // 寻找最大采样值 - for (int64_t i=0; i maxSample) { - maxSample = sample; - } - } - // 把采样值缩放到 int_16 范围 - for (int64_t i=0; i(GetWavSize()) / sizeof(WavDataType) / static_cast(WAV_SAMPLE_RATE) * 1000; + return static_cast(GetWavSize()) / sizeof(WavDataType) / static_cast(wav_sample_rate_) * 1000; } // 获取RTF(合成时间 / 音频时长) @@ -195,32 +204,6 @@ public: wav_.clear(); } - struct WavHeader { - // RIFF 头 - char riff[4] = {'R', 'I', 'F', 'F'}; - uint32_t size = 0; - char wave[4] = {'W', 'A', 'V', 'E'}; - - // FMT 头 - char fmt[4] = {'f', 'm', 't', ' '}; - uint32_t fmt_size = 16; - uint16_t audio_format = 1; // 1为整数编码,3为浮点编码 - uint16_t num_channels = 1; - - // WAV采样率 - // 如果播放速度和音调异常,请修改采样率 - // 常见采样率:16000, 24000, 32000, 44100, 48000, 96000 - uint32_t sample_rate = WAV_SAMPLE_RATE; - - uint32_t byte_rate = 64000; - uint16_t block_align = 2; - uint16_t bits_per_sample = sizeof(WavDataType) * 8; - - // DATA 头 - char data[4] = {'d', 'a', 't', 'a'}; - uint32_t data_size = 0; - }; - bool WriteWavToFile(const std::string &wavPath) { std::ofstream fout(wavPath, std::ios::binary); if (!fout.is_open()) { @@ -229,8 +212,10 @@ public: // 写入头信息 WavHeader header; + header.audio_format = GetWavAudioFormat(); header.data_size = GetWavSize(); header.size = sizeof(header) - 8 + header.data_size; + header.sample_rate = wav_sample_rate_; header.byte_rate = header.sample_rate * header.num_channels * header.bits_per_sample / 8; header.block_align = header.num_channels * header.bits_per_sample / 8; fout.write(reinterpret_cast(&header), sizeof(header)); @@ -244,7 +229,43 @@ public: private: float inference_time_ = 0; + uint32_t wav_sample_rate_ = 0; + std::vector wav_; std::shared_ptr AM_predictor_ = nullptr; std::shared_ptr VOC_predictor_ = nullptr; - std::vector wav_; }; + +template<> +uint16_t Predictor::GetWavAudioFormat() { + return Predictor::WAV_FORMAT_16BIT_PCM; +} + +template<> +uint16_t Predictor::GetWavAudioFormat() { + return Predictor::WAV_FORMAT_32BIT_FLOAT; +} + +// 保存 16-bit PCM 格式 WAV +template<> +void Predictor::SaveFloatWav(float *floatWav, int64_t size) { + wav_.resize(size); + float maxSample = 0.01; + // 寻找最大采样值 + for (int64_t i=0; i maxSample) { + maxSample = sample; + } + } + // 把采样值缩放到 int_16 范围 + for (int64_t i=0; i +void Predictor::SaveFloatWav(float *floatWav, int64_t size) { + wav_.resize(size); + std::copy_n(floatWav, size, wav_.data()); +} diff --git a/demos/TTSArmLinux/src/main.cc b/demos/TTSArmLinux/src/main.cc index 6a430e4cc..2285b28b3 100644 --- a/demos/TTSArmLinux/src/main.cc +++ b/demos/TTSArmLinux/src/main.cc @@ -49,8 +49,24 @@ int main(int argc, char *argv[]) { return -1; } - Predictor predictor; - if (!predictor.Init(AMModelPath, VOCModelPath, 1, "LITE_POWER_HIGH")) { + // 模板参数:WAV数据类型 + // 可在 int16_t 和 float 之间切换, + // 用于生成 16-bit PCM 或 32-bit IEEE float 格式的 WAV + Predictor predictor; + //Predictor predictor; + + // WAV采样率(必须与模型输出匹配) + // 如果播放速度和音调异常,请修改采样率 + // 常见采样率:16000, 24000, 32000, 44100, 48000, 96000 + const uint32_t wavSampleRate = 24000; + + // CPU线程数 + const int cpuThreadNum = 1; + + // CPU电源模式 + const PowerMode cpuPowerMode = PowerMode::LITE_POWER_HIGH; + + if (!predictor.Init(AMModelPath, VOCModelPath, cpuPowerMode, cpuThreadNum, wavSampleRate)) { std::cerr << "predictor init failed" << std::endl; return -1; }