|
|
|
@ -65,14 +65,14 @@ PredictorInterface::~PredictorInterface() {}
|
|
|
|
|
template <typename WavDataType>
|
|
|
|
|
class Predictor : public PredictorInterface {
|
|
|
|
|
public:
|
|
|
|
|
virtual bool Init(const std::string &AcousticModelPath,
|
|
|
|
|
const std::string &VocoderPath,
|
|
|
|
|
PowerMode cpuPowerMode,
|
|
|
|
|
int cpuThreadNum,
|
|
|
|
|
// WAV采样率(必须与模型输出匹配)
|
|
|
|
|
// 如果播放速度和音调异常,请修改采样率
|
|
|
|
|
// 常见采样率:16000, 24000, 32000, 44100, 48000, 96000
|
|
|
|
|
uint32_t wavSampleRate) override {
|
|
|
|
|
bool Init(const std::string &AcousticModelPath,
|
|
|
|
|
const std::string &VocoderPath,
|
|
|
|
|
PowerMode cpuPowerMode,
|
|
|
|
|
int cpuThreadNum,
|
|
|
|
|
// WAV采样率(必须与模型输出匹配)
|
|
|
|
|
// 如果播放速度和音调异常,请修改采样率
|
|
|
|
|
// 常见采样率:16000, 24000, 32000, 44100, 48000, 96000
|
|
|
|
|
uint32_t wavSampleRate) override {
|
|
|
|
|
// Release model if exists
|
|
|
|
|
ReleaseModel();
|
|
|
|
|
|
|
|
|
@ -96,7 +96,7 @@ class Predictor : public PredictorInterface {
|
|
|
|
|
ReleaseWav();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual std::shared_ptr<PaddlePredictor> LoadModel(
|
|
|
|
|
std::shared_ptr<PaddlePredictor> LoadModel(
|
|
|
|
|
const std::string &modelPath,
|
|
|
|
|
int cpuThreadNum,
|
|
|
|
|
PowerMode cpuPowerMode) override {
|
|
|
|
@ -113,12 +113,12 @@ class Predictor : public PredictorInterface {
|
|
|
|
|
return CreatePaddlePredictor<MobileConfig>(config);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual void ReleaseModel() override {
|
|
|
|
|
void ReleaseModel() override {
|
|
|
|
|
acoustic_model_predictor_ = nullptr;
|
|
|
|
|
vocoder_predictor_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual bool RunModel(const std::vector<int64_t> &phones) override {
|
|
|
|
|
bool RunModel(const std::vector<int64_t> &phones) override {
|
|
|
|
|
if (!IsLoaded()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
@ -139,7 +139,7 @@ class Predictor : public PredictorInterface {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual std::unique_ptr<const Tensor> GetAcousticModelOutput(
|
|
|
|
|
std::unique_ptr<const Tensor> GetAcousticModelOutput(
|
|
|
|
|
const std::vector<int64_t> &phones) override {
|
|
|
|
|
auto phones_handle = acoustic_model_predictor_->GetInput(0);
|
|
|
|
|
phones_handle->Resize({static_cast<int64_t>(phones.size())});
|
|
|
|
@ -159,7 +159,7 @@ class Predictor : public PredictorInterface {
|
|
|
|
|
return am_output_handle;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual std::unique_ptr<const Tensor> GetVocoderOutput(
|
|
|
|
|
std::unique_ptr<const Tensor> GetVocoderOutput(
|
|
|
|
|
std::unique_ptr<const Tensor> &&amOutput) override {
|
|
|
|
|
auto mel_handle = vocoder_predictor_->GetInput(0);
|
|
|
|
|
// [?, 80]
|
|
|
|
@ -182,7 +182,7 @@ class Predictor : public PredictorInterface {
|
|
|
|
|
return voc_output_handle;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual void VocoderOutputToWav(
|
|
|
|
|
void VocoderOutputToWav(
|
|
|
|
|
std::unique_ptr<const Tensor> &&vocOutput) override {
|
|
|
|
|
// 获取输出Tensor的数据
|
|
|
|
|
int64_t output_size = 1;
|
|
|
|
@ -194,35 +194,31 @@ class Predictor : public PredictorInterface {
|
|
|
|
|
SaveFloatWav(output_data, output_size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual void SaveFloatWav(float *floatWav, int64_t size) override;
|
|
|
|
|
void SaveFloatWav(float *floatWav, int64_t size) override;
|
|
|
|
|
|
|
|
|
|
virtual bool IsLoaded() override {
|
|
|
|
|
bool IsLoaded() override {
|
|
|
|
|
return acoustic_model_predictor_ != nullptr &&
|
|
|
|
|
vocoder_predictor_ != nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual float GetInferenceTime() override { return inference_time_; }
|
|
|
|
|
float GetInferenceTime() override { return inference_time_; }
|
|
|
|
|
|
|
|
|
|
const std::vector<WavDataType> &GetWav() { return wav_; }
|
|
|
|
|
|
|
|
|
|
virtual int GetWavSize() override {
|
|
|
|
|
return wav_.size() * sizeof(WavDataType);
|
|
|
|
|
}
|
|
|
|
|
int GetWavSize() override { return wav_.size() * sizeof(WavDataType); }
|
|
|
|
|
|
|
|
|
|
// 获取WAV持续时间(单位:毫秒)
|
|
|
|
|
virtual float GetWavDuration() override {
|
|
|
|
|
float GetWavDuration() override {
|
|
|
|
|
return static_cast<float>(GetWavSize()) / sizeof(WavDataType) /
|
|
|
|
|
static_cast<float>(wav_sample_rate_) * 1000;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取RTF(合成时间 / 音频时长)
|
|
|
|
|
virtual float GetRTF() override {
|
|
|
|
|
return GetInferenceTime() / GetWavDuration();
|
|
|
|
|
}
|
|
|
|
|
float GetRTF() override { return GetInferenceTime() / GetWavDuration(); }
|
|
|
|
|
|
|
|
|
|
virtual void ReleaseWav() override { wav_.clear(); }
|
|
|
|
|
void ReleaseWav() override { wav_.clear(); }
|
|
|
|
|
|
|
|
|
|
virtual bool WriteWavToFile(const std::string &wavPath) override {
|
|
|
|
|
bool WriteWavToFile(const std::string &wavPath) override {
|
|
|
|
|
std::ofstream fout(wavPath, std::ios::binary);
|
|
|
|
|
if (!fout.is_open()) {
|
|
|
|
|
return false;
|
|
|
|
|