refac: [demos/TTSArmLinux] support 16-bit PCM and 32-bit IEEE float WAV output with class template

pull/3018/head
彭逸豪 3 years ago
parent 2772022432
commit e5fb491f1d

@ -9,18 +9,51 @@
using namespace paddle::lite_api;
// WAV sample rate.
// If playback speed or pitch sounds wrong, change the sample rate.
// Common sample rates: 16000, 24000, 32000, 44100, 48000, 96000
#define WAV_SAMPLE_RATE 24000
// WAV sample data type.
// Defined here so it can be switched between int16_t and float.
typedef int16_t WavDataType;
// WavDataType: WAV数据类型
// 可在 int16_t 和 float 之间切换,
// 用于生成 16-bit PCM 或 32-bit IEEE float 格式的 WAV
template<typename WavDataType>
class Predictor {
public:
bool Init(const std::string &AMModelPath, const std::string &VOCModelPath, int cpuThreadNum, const std::string &cpuPowerMode) {
// In-memory image of a WAV file header (RIFF chunk, "fmt " chunk and the
// "data" chunk header). WriteWavToFile writes this struct to the output
// stream byte-for-byte, so the field order and field types must not change.
// Fields left at 0 here (size, audio_format, sample_rate, byte_rate,
// block_align, data_size) are filled in by WriteWavToFile before writing.
struct WavHeader {
// RIFF header
char riff[4] = {'R', 'I', 'F', 'F'};
uint32_t size = 0;
char wave[4] = {'W', 'A', 'V', 'E'};
// FMT header
char fmt[4] = {'f', 'm', 't', ' '};
uint32_t fmt_size = 16;
uint16_t audio_format = 0;
uint16_t num_channels = 1;
uint32_t sample_rate = 0;
uint32_t byte_rate = 0;
uint16_t block_align = 0;
// Sample width in bits, derived from the class template parameter.
uint16_t bits_per_sample = sizeof(WavDataType) * 8;
// DATA header
char data[4] = {'d', 'a', 't', 'a'};
uint32_t data_size = 0;
};
// Values stored in WavHeader::audio_format for the two supported sample types.
enum WavAudioFormat {
WAV_FORMAT_16BIT_PCM = 1, // 16-bit PCM format
WAV_FORMAT_32BIT_FLOAT = 3 // 32-bit IEEE float format
};
// Returns the WAV audio-format code for this instantiation.
// The return value is determined by WavDataType through template
// specialization (see the explicit specializations after the class).
inline uint16_t GetWavAudioFormat();
bool Init(
const std::string &AMModelPath,
const std::string &VOCModelPath,
PowerMode cpuPowerMode,
int cpuThreadNum,
// WAV采样率必须与模型输出匹配
// 如果播放速度和音调异常,请修改采样率
// 常见采样率16000, 24000, 32000, 44100, 48000, 96000
uint32_t wavSampleRate
) {
// Release model if exists
ReleaseModel();
@ -33,6 +66,8 @@ public:
return false;
}
wav_sample_rate_ = wavSampleRate;
return true;
}
@ -41,7 +76,7 @@ public:
ReleaseWav();
}
std::shared_ptr<PaddlePredictor> LoadModel(const std::string &modelPath, int cpuThreadNum, const std::string &cpuPowerMode) {
std::shared_ptr<PaddlePredictor> LoadModel(const std::string &modelPath, int cpuThreadNum, PowerMode cpuPowerMode) {
if (modelPath.empty()) {
return nullptr;
}
@ -50,23 +85,7 @@ public:
MobileConfig config;
config.set_model_from_file(modelPath);
config.set_threads(cpuThreadNum);
if (cpuPowerMode == "LITE_POWER_HIGH") {
config.set_power_mode(PowerMode::LITE_POWER_HIGH);
} else if (cpuPowerMode == "LITE_POWER_LOW") {
config.set_power_mode(PowerMode::LITE_POWER_LOW);
} else if (cpuPowerMode == "LITE_POWER_FULL") {
config.set_power_mode(PowerMode::LITE_POWER_FULL);
} else if (cpuPowerMode == "LITE_POWER_NO_BIND") {
config.set_power_mode(PowerMode::LITE_POWER_NO_BIND);
} else if (cpuPowerMode == "LITE_POWER_RAND_HIGH") {
config.set_power_mode(PowerMode::LITE_POWER_RAND_HIGH);
} else if (cpuPowerMode == "LITE_POWER_RAND_LOW") {
config.set_power_mode(PowerMode::LITE_POWER_RAND_LOW);
} else {
std::cerr << "Unknown cpu power mode!" << std::endl;
return nullptr;
}
config.set_power_mode(cpuPowerMode);
return CreatePaddlePredictor<MobileConfig>(config);
}
@ -85,7 +104,7 @@ public:
auto start = std::chrono::system_clock::now();
// 执行推理
VOCOutputToWav(GetAMOutput(phones));
VOCOutputToWav(GetVOCOutput(GetAMOutput(phones)));
// 计时结束
auto end = std::chrono::system_clock::now();
@ -116,12 +135,12 @@ public:
return am_output_handle;
}
void VOCOutputToWav(std::unique_ptr<const Tensor> &&input) {
std::unique_ptr<const Tensor> GetVOCOutput(std::unique_ptr<const Tensor> &&amOutput) {
auto mel_handle = VOC_predictor_->GetInput(0);
// [?, 80]
auto dims = input->shape();
auto dims = amOutput->shape();
mel_handle->Resize(dims);
auto am_output_data = input->mutable_data<float>();
auto am_output_data = amOutput->mutable_data<float>();
mel_handle->CopyFromCpu(am_output_data);
VOC_predictor_->Run();
@ -135,12 +154,16 @@ public:
}
std::cout << std::endl;
return voc_output_handle;
}
void VOCOutputToWav(std::unique_ptr<const Tensor> &&vocOutput) {
// 获取输出Tensor的数据
int64_t output_size = 1;
for (auto dim : voc_output_handle->shape()) {
for (auto dim : vocOutput->shape()) {
output_size *= dim;
}
auto output_data = voc_output_handle->mutable_data<float>();
auto output_data = vocOutput->mutable_data<float>();
SaveFloatWav(output_data, output_size);
}
@ -149,21 +172,7 @@ public:
return (number < 0) ? -number : number;
}
void SaveFloatWav(float *floatWav, int64_t size) {
wav_.resize(size);
float maxSample = 0.01;
// 寻找最大采样值
for (int64_t i=0; i<size; i++) {
float sample = Abs(floatWav[i]);
if (sample > maxSample) {
maxSample = sample;
}
}
// 把采样值缩放到 int_16 范围
for (int64_t i=0; i<size; i++) {
wav_[i] = floatWav[i] * 32767.0f / maxSample;
}
}
void SaveFloatWav(float *floatWav, int64_t size);
bool IsLoaded() {
return AM_predictor_ != nullptr && VOC_predictor_ != nullptr;
@ -183,7 +192,7 @@ public:
// 获取WAV持续时间单位毫秒
float GetWavDuration() {
return static_cast<float>(GetWavSize()) / sizeof(WavDataType) / static_cast<float>(WAV_SAMPLE_RATE) * 1000;
return static_cast<float>(GetWavSize()) / sizeof(WavDataType) / static_cast<float>(wav_sample_rate_) * 1000;
}
// 获取RTF合成时间 / 音频时长)
@ -195,32 +204,6 @@ public:
wav_.clear();
}
// In-memory image of a mono WAV file header (RIFF chunk, "fmt " chunk and
// the "data" chunk header). The struct is written to the output file
// byte-for-byte, so the field order and field types must not change.
struct WavHeader {
    // RIFF header
    char riff[4] = {'R', 'I', 'F', 'F'};
    uint32_t size = 0;
    char wave[4] = {'W', 'A', 'V', 'E'};

    // FMT header
    char fmt[4] = {'f', 'm', 't', ' '};
    uint32_t fmt_size = 16;
    uint16_t audio_format = 1; // 1 = integer (PCM) encoding, 3 = floating-point encoding
    uint16_t num_channels = 1;
    // WAV sample rate.
    // If playback speed or pitch sounds wrong, change the sample rate.
    // Common sample rates: 16000, 24000, 32000, 44100, 48000, 96000
    uint32_t sample_rate = WAV_SAMPLE_RATE;
    // Bytes per second = sample rate * channels (1) * bytes per sample.
    // Previously hard-coded to 64000, which does not match 24000 Hz mono
    // 16-bit audio (48000 bytes per second).
    uint32_t byte_rate = static_cast<uint32_t>(WAV_SAMPLE_RATE * sizeof(WavDataType));
    // Bytes per sample frame = channels (1) * bytes per sample.
    uint16_t block_align = static_cast<uint16_t>(sizeof(WavDataType));
    uint16_t bits_per_sample = sizeof(WavDataType) * 8;

    // DATA header
    char data[4] = {'d', 'a', 't', 'a'};
    uint32_t data_size = 0;
};
bool WriteWavToFile(const std::string &wavPath) {
std::ofstream fout(wavPath, std::ios::binary);
if (!fout.is_open()) {
@ -229,8 +212,10 @@ public:
// 写入头信息
WavHeader header;
header.audio_format = GetWavAudioFormat();
header.data_size = GetWavSize();
header.size = sizeof(header) - 8 + header.data_size;
header.sample_rate = wav_sample_rate_;
header.byte_rate = header.sample_rate * header.num_channels * header.bits_per_sample / 8;
header.block_align = header.num_channels * header.bits_per_sample / 8;
fout.write(reinterpret_cast<const char*>(&header), sizeof(header));
@ -244,7 +229,43 @@ public:
private:
float inference_time_ = 0;
uint32_t wav_sample_rate_ = 0;
std::vector<WavDataType> wav_;
std::shared_ptr<PaddlePredictor> AM_predictor_ = nullptr;
std::shared_ptr<PaddlePredictor> VOC_predictor_ = nullptr;
std::vector<WavDataType> wav_;
};
// Explicit specialization for 16-bit PCM output: reports the integer-PCM
// format code.
// An explicit specialization is an ordinary function definition, so it is
// marked inline to avoid multiple-definition link errors if this file is
// included from more than one translation unit.
template<>
inline uint16_t Predictor<int16_t>::GetWavAudioFormat() {
    return Predictor::WAV_FORMAT_16BIT_PCM;
}
// Explicit specialization for 32-bit float output: reports the IEEE-float
// format code.
// An explicit specialization is an ordinary function definition, so it is
// marked inline to avoid multiple-definition link errors if this file is
// included from more than one translation unit.
template<>
inline uint16_t Predictor<float>::GetWavAudioFormat() {
    return Predictor::WAV_FORMAT_32BIT_FLOAT;
}
// Saves samples as 16-bit PCM WAV data.
//
// Stores `size` float samples from `floatWav` into wav_, peak-normalising
// them to the 16-bit integer range.
//
// floatWav: pointer to the float samples to convert.
// size:     number of samples readable through floatWav.
//
// A null pointer or a non-positive size leaves wav_ empty instead of
// dereferencing the pointer or turning a negative count into a huge
// unsigned allocation request.
//
// Marked inline because an explicit specialization is an ordinary function
// definition; without it, including this file from more than one
// translation unit produces multiple-definition link errors.
template<>
inline void Predictor<int16_t>::SaveFloatWav(float *floatWav, int64_t size) {
    if (floatWav == nullptr || size <= 0) {
        wav_.clear();
        return;
    }

    wav_.resize(static_cast<std::size_t>(size));

    // Find the peak absolute sample value. Starting from 0.01 means the
    // divisor used below is never smaller than 0.01.
    float maxSample = 0.01f;
    for (int64_t i = 0; i < size; i++) {
        const float sample = Abs(floatWav[i]);
        if (sample > maxSample) {
            maxSample = sample;
        }
    }

    // Scale every sample so that the peak maps to +/-32767. The conversion
    // truncates toward zero, exactly like the implicit conversion it replaces.
    for (int64_t i = 0; i < size; i++) {
        wav_[i] = static_cast<int16_t>(floatWav[i] * 32767.0f / maxSample);
    }
}
// Saves samples as 32-bit IEEE float WAV data.
//
// Copies `size` float samples from `floatWav` into wav_ unchanged (no
// scaling is applied).
//
// floatWav: pointer to the float samples to copy.
// size:     number of samples readable through floatWav.
//
// A null pointer or a non-positive size leaves wav_ empty instead of
// dereferencing the pointer or turning a negative count into a huge
// unsigned allocation request.
//
// Marked inline because an explicit specialization is an ordinary function
// definition; without it, including this file from more than one
// translation unit produces multiple-definition link errors.
template<>
inline void Predictor<float>::SaveFloatWav(float *floatWav, int64_t size) {
    if (floatWav == nullptr || size <= 0) {
        wav_.clear();
        return;
    }

    wav_.resize(static_cast<std::size_t>(size));
    std::copy_n(floatWav, size, wav_.data());
}

@ -49,8 +49,24 @@ int main(int argc, char *argv[]) {
return -1;
}
Predictor predictor;
if (!predictor.Init(AMModelPath, VOCModelPath, 1, "LITE_POWER_HIGH")) {
// 模板参数WAV数据类型
// 可在 int16_t 和 float 之间切换,
// 用于生成 16-bit PCM 或 32-bit IEEE float 格式的 WAV
Predictor<int16_t> predictor;
//Predictor<float> predictor;
// WAV采样率必须与模型输出匹配
// 如果播放速度和音调异常,请修改采样率
// 常见采样率16000, 24000, 32000, 44100, 48000, 96000
const uint32_t wavSampleRate = 24000;
// CPU线程数
const int cpuThreadNum = 1;
// CPU电源模式
const PowerMode cpuPowerMode = PowerMode::LITE_POWER_HIGH;
if (!predictor.Init(AMModelPath, VOCModelPath, cpuPowerMode, cpuThreadNum, wavSampleRate)) {
std::cerr << "predictor init failed" << std::endl;
return -1;
}

Loading…
Cancel
Save