fix TTSArmLinux

pull/3030/head
TianYuan 3 years ago
parent 6322cea1a4
commit 96038fb01b

@@ -1,7 +1,20 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <chrono>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
@@ -10,24 +23,28 @@
using namespace paddle::lite_api;

class PredictorInterface {
  public:
    virtual ~PredictorInterface() = 0;
    virtual bool Init(const std::string &AcousticModelPath,
                      const std::string &VocoderPath,
                      PowerMode cpuPowerMode,
                      int cpuThreadNum,
                      // WAV sample rate (must match the model output)
                      // If playback speed or pitch is wrong, adjust this rate
                      // Common rates: 16000, 24000, 32000, 44100, 48000, 96000
                      uint32_t wavSampleRate) = 0;
    virtual std::shared_ptr<PaddlePredictor> LoadModel(
        const std::string &modelPath,
        int cpuThreadNum,
        PowerMode cpuPowerMode) = 0;
    virtual void ReleaseModel() = 0;
    virtual bool RunModel(const std::vector<int64_t> &phones) = 0;
    virtual std::unique_ptr<const Tensor> GetAcousticModelOutput(
        const std::vector<int64_t> &phones) = 0;
    virtual std::unique_ptr<const Tensor> GetVocoderOutput(
        std::unique_ptr<const Tensor> &&amOutput) = 0;
    virtual void VocoderOutputToWav(
        std::unique_ptr<const Tensor> &&vocOutput) = 0;
    virtual void SaveFloatWav(float *floatWav, int64_t size) = 0;
    virtual bool IsLoaded() = 0;
    virtual float GetInferenceTime() = 0;
@@ -45,23 +62,22 @@ PredictorInterface::~PredictorInterface() {}

// WavDataType: WAV data type
// Switchable between int16_t and float,
// used to generate WAV in 16-bit PCM or 32-bit IEEE float format
template <typename WavDataType>
class Predictor : public PredictorInterface {
  public:
    virtual bool Init(const std::string &AcousticModelPath,
                      const std::string &VocoderPath,
                      PowerMode cpuPowerMode,
                      int cpuThreadNum,
                      // WAV sample rate (must match the model output)
                      // If playback speed or pitch is wrong, adjust this rate
                      // Common rates: 16000, 24000, 32000, 44100, 48000, 96000
                      uint32_t wavSampleRate) override {
        // Release model if exists
        ReleaseModel();

        acoustic_model_predictor_ =
            LoadModel(AcousticModelPath, cpuThreadNum, cpuPowerMode);
        if (acoustic_model_predictor_ == nullptr) {
            return false;
        }
@@ -80,7 +96,10 @@ public:
        ReleaseWav();
    }

    virtual std::shared_ptr<PaddlePredictor> LoadModel(
        const std::string &modelPath,
        int cpuThreadNum,
        PowerMode cpuPowerMode) override {
        if (modelPath.empty()) {
            return nullptr;
        }
@@ -115,12 +134,13 @@ public:
        // Compute elapsed time
        std::chrono::duration<float> duration = end - start;
        inference_time_ = duration.count() * 1000;  // unit: milliseconds

        return true;
    }

    virtual std::unique_ptr<const Tensor> GetAcousticModelOutput(
        const std::vector<int64_t> &phones) override {
        auto phones_handle = acoustic_model_predictor_->GetInput(0);
        phones_handle->Resize({static_cast<int64_t>(phones.size())});
        phones_handle->CopyFromCpu(phones.data());
@@ -139,7 +159,8 @@ public:
        return am_output_handle;
    }

    virtual std::unique_ptr<const Tensor> GetVocoderOutput(
        std::unique_ptr<const Tensor> &&amOutput) override {
        auto mel_handle = vocoder_predictor_->GetInput(0);
        // [?, 80]
        auto dims = amOutput->shape();
@@ -161,7 +182,8 @@ public:
        return voc_output_handle;
    }

    virtual void VocoderOutputToWav(
        std::unique_ptr<const Tensor> &&vocOutput) override {
        // Get the output tensor data
        int64_t output_size = 1;
        for (auto dim : vocOutput->shape()) {
@@ -175,16 +197,13 @@ public:
    virtual void SaveFloatWav(float *floatWav, int64_t size) override;

    virtual bool IsLoaded() override {
        return acoustic_model_predictor_ != nullptr &&
               vocoder_predictor_ != nullptr;
    }

    virtual float GetInferenceTime() override { return inference_time_; }

    const std::vector<WavDataType> &GetWav() { return wav_; }

    virtual int GetWavSize() override {
        return wav_.size() * sizeof(WavDataType);
@@ -192,7 +211,8 @@ public:
    // Get WAV duration (in milliseconds)
    virtual float GetWavDuration() override {
        return static_cast<float>(GetWavSize()) / sizeof(WavDataType) /
               static_cast<float>(wav_sample_rate_) * 1000;
    }

    // Get RTF (synthesis time / audio duration)
@@ -200,9 +220,7 @@ public:
        return GetInferenceTime() / GetWavDuration();
    }

    virtual void ReleaseWav() override { wav_.clear(); }

    virtual bool WriteWavToFile(const std::string &wavPath) override {
        std::ofstream fout(wavPath, std::ios::binary);
@@ -216,18 +234,20 @@ public:
        header.data_size = GetWavSize();
        header.size = sizeof(header) - 8 + header.data_size;
        header.sample_rate = wav_sample_rate_;
        header.byte_rate = header.sample_rate * header.num_channels *
                           header.bits_per_sample / 8;
        header.block_align = header.num_channels * header.bits_per_sample / 8;

        fout.write(reinterpret_cast<const char *>(&header), sizeof(header));

        // Write the WAV data
        fout.write(reinterpret_cast<const char *>(wav_.data()),
                   header.data_size);

        fout.close();

        return true;
    }
  protected:
    struct WavHeader {
        // RIFF header
        char riff[4] = {'R', 'I', 'F', 'F'};
@@ -250,19 +270,17 @@ protected:
    };

    enum WavAudioFormat {
        WAV_FORMAT_16BIT_PCM = 1,   // 16-bit PCM format
        WAV_FORMAT_32BIT_FLOAT = 3  // 32-bit IEEE float format
    };

  protected:
    // The return value is selected by WavDataType via template specialization
    inline uint16_t GetWavAudioFormat();

    inline float Abs(float number) { return (number < 0) ? -number : number; }

  protected:
    float inference_time_ = 0;
    uint32_t wav_sample_rate_ = 0;
    std::vector<WavDataType> wav_;
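The diff only shows fragments of WavHeader, but WriteWavToFile above fills in size, sample_rate, byte_rate, block_align, and data_size, which corresponds to the canonical 44-byte RIFF/WAVE header. A minimal sketch of such a header, assuming the standard field order; the struct actually defined in this commit may name or default fields differently:

```cpp
// Sketch of a standard 44-byte WAV header (assumed layout, not the
// commit's exact struct). It must be written without padding.
#pragma pack(push, 1)
struct WavHeaderSketch {
    // RIFF chunk
    char riff[4] = {'R', 'I', 'F', 'F'};
    uint32_t size = 0;              // total file size - 8
    char wave[4] = {'W', 'A', 'V', 'E'};
    // fmt sub-chunk
    char fmt[4] = {'f', 'm', 't', ' '};
    uint32_t fmt_size = 16;
    uint16_t audio_format = 1;      // 1 = PCM, 3 = IEEE float
    uint16_t num_channels = 1;
    uint32_t sample_rate = 0;
    uint32_t byte_rate = 0;         // sample_rate * num_channels * bits_per_sample / 8
    uint16_t block_align = 0;       // num_channels * bits_per_sample / 8
    uint16_t bits_per_sample = 16;
    // data sub-chunk
    char data[4] = {'d', 'a', 't', 'a'};
    uint32_t data_size = 0;         // raw sample bytes that follow
};
#pragma pack(pop)
```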
@@ -270,36 +288,36 @@ protected:
    std::shared_ptr<PaddlePredictor> vocoder_predictor_ = nullptr;
};

template <>
uint16_t Predictor<int16_t>::GetWavAudioFormat() {
    return Predictor::WAV_FORMAT_16BIT_PCM;
}

template <>
uint16_t Predictor<float>::GetWavAudioFormat() {
    return Predictor::WAV_FORMAT_32BIT_FLOAT;
}

// Save WAV in 16-bit PCM format
template <>
void Predictor<int16_t>::SaveFloatWav(float *floatWav, int64_t size) {
    wav_.resize(size);
    float maxSample = 0.01;
    // Find the maximum sample value
    for (int64_t i = 0; i < size; i++) {
        float sample = Abs(floatWav[i]);
        if (sample > maxSample) {
            maxSample = sample;
        }
    }
    // Scale samples to the int16 range
    for (int64_t i = 0; i < size; i++) {
        wav_[i] = floatWav[i] * 32767.0f / maxSample;
    }
}

// Save WAV in 32-bit IEEE float format
template <>
void Predictor<float>::SaveFloatWav(float *floatWav, int64_t size) {
    wav_.resize(size);
    std::copy_n(floatWav, size, wav_.data());
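Taken together, the interface above implies a simple calling sequence: pick a WavDataType (int16_t for 16-bit PCM, float for 32-bit IEEE float), then Init, RunModel, WriteWavToFile. The sketch below is a hypothetical usage example, not code from this commit; the vocoder path, sample rate, thread count, and phone IDs are placeholders, and real phone IDs come from the text front end shown in the second file.

```cpp
// Hypothetical usage of the Predictor template defined above.
#include <memory>
#include <vector>
#include "Predictor.hpp"

int demo() {
    // int16_t -> 16-bit PCM WAV; use Predictor<float> for 32-bit IEEE float.
    auto predictor = std::make_unique<Predictor<int16_t>>();

    if (!predictor->Init("./models/cpu/fastspeech2_csmsc_arm.nb",  // acoustic model
                         "./models/cpu/mb_melgan_csmsc_arm.nb",    // vocoder (placeholder path)
                         paddle::lite_api::PowerMode::LITE_POWER_HIGH,
                         /*cpuThreadNum=*/1,
                         /*wavSampleRate=*/24000)) {
        return -1;
    }

    // Phone ID sequence produced by the front end; values are placeholders.
    std::vector<int64_t> phones = {12, 34, 56};
    if (!predictor->RunModel(phones)) {
        return -1;
    }
    return predictor->WriteWavToFile("./output/tts.wav") ? 0 : -1;
}
```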

@@ -1,23 +1,48 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <front/front_interface.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <paddle_api.h>
#include <cstdlib>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include "Predictor.hpp"
using namespace paddle::lite_api;

DEFINE_string(
    sentence,
    "你好,欢迎使用语音合成服务",
    "Text to be synthesized (Chinese only. English will crash the program.)");
DEFINE_string(front_conf, "./front.conf", "Front configuration file");
DEFINE_string(acoustic_model,
              "./models/cpu/fastspeech2_csmsc_arm.nb",
              "Acoustic model .nb file");
DEFINE_string(vocoder,
              "./models/cpu/fastspeech2_csmsc_arm.nb",
              "vocoder .nb file");
DEFINE_string(output_wav, "./output/tts.wav", "Output WAV file");
DEFINE_string(wav_bit_depth,
              "16",
              "WAV bit depth, 16 (16-bit PCM) or 32 (32-bit IEEE float)");
DEFINE_string(wav_sample_rate,
              "24000",
              "WAV sample rate, should match the output of the vocoder");
DEFINE_string(cpu_thread, "1", "CPU thread numbers");

int main(int argc, char *argv[]) {
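Because the flags above are plain strings, main() has to convert the numeric ones and pick the Predictor instantiation that matches the requested bit depth. The fragment below is only a sketch of how that dispatch could look under those assumptions; the conversion and error handling actually used in this commit may differ.

```cpp
// Sketch (inside main, after ParseCommandLineFlags): assumed flag handling.
int cpuThreadNum = std::stoi(FLAGS_cpu_thread);
uint32_t wavSampleRate = static_cast<uint32_t>(std::stoi(FLAGS_wav_sample_rate));

std::unique_ptr<PredictorInterface> predictor;
if (FLAGS_wav_bit_depth == "16") {
    predictor = std::make_unique<Predictor<int16_t>>();  // 16-bit PCM
} else if (FLAGS_wav_bit_depth == "32") {
    predictor = std::make_unique<Predictor<float>>();    // 32-bit IEEE float
} else {
    LOG(ERROR) << "Unsupported wav_bit_depth: " << FLAGS_wav_bit_depth;
    return -1;
}
```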
@@ -53,7 +78,7 @@ int main(int argc, char *argv[]) {
    // Convert traditional Chinese to simplified Chinese
    std::wstring sentence_simp;
    front_inst->Trand2Simp(ws_sentence, &sentence_simp);
    ws_sentence = sentence_simp;

    std::string s_sentence;
@@ -63,26 +88,28 @@ int main(int argc, char *argv[]) {
    // Split into sentences by punctuation
    LOG(INFO) << "Start to segment sentences by punctuation";
    front_inst->SplitByPunc(ws_sentence, &sentence_part);
    LOG(INFO) << "Segment sentences through punctuation successfully";

    // Get the phoneme IDs of each sentence after splitting
    LOG(INFO)
        << "Start to get the phoneme and tone id sequence of each sentence";
    for (int i = 0; i < sentence_part.size(); i++) {
        LOG(INFO) << "Raw sentence is: "
                  << ppspeech::wstring2utf8string(sentence_part[i]);
        front_inst->SentenceNormalize(&sentence_part[i]);
        s_sentence = ppspeech::wstring2utf8string(sentence_part[i]);
        LOG(INFO) << "After normalization sentence is: " << s_sentence;

        if (0 != front_inst->GetSentenceIds(s_sentence, &phoneids, &toneids)) {
            LOG(ERROR) << "TTS inst get sentence phoneids and toneids failed";
            return -1;
        }
    }
    LOG(INFO) << "The phoneids of the sentence is: "
              << limonp::Join(phoneids.begin(), phoneids.end(), " ");
    LOG(INFO) << "The toneids of the sentence is: "
              << limonp::Join(toneids.begin(), toneids.end(), " ");

    LOG(INFO) << "Get the phoneme id sequence of each sentence successfully";
@@ -99,13 +126,19 @@ int main(int argc, char *argv[]) {
    // CPU power mode
    const PowerMode cpuPowerMode = PowerMode::LITE_POWER_HIGH;

    if (!predictor->Init(FLAGS_acoustic_model,
                         FLAGS_vocoder,
                         cpuPowerMode,
                         cpuThreadNum,
                         wavSampleRate)) {
        LOG(ERROR) << "predictor init failed" << std::endl;
        return -1;
    }

    std::vector<int64_t> phones(phoneids.size());
    std::transform(phoneids.begin(), phoneids.end(), phones.begin(), [](int x) {
        return static_cast<int64_t>(x);
    });

    if (!predictor->RunModel(phones)) {
        LOG(ERROR) << "predictor run model failed" << std::endl;
@@ -113,7 +146,8 @@ int main(int argc, char *argv[]) {
    }

    LOG(INFO) << "Inference time: " << predictor->GetInferenceTime() << " ms, "
              << "WAV size (without header): " << predictor->GetWavSize()
              << " bytes, "
              << "WAV duration: " << predictor->GetWavDuration() << " ms, "
              << "RTF: " << predictor->GetRTF() << std::endl;
