rm virtual in TTSArmLinux

pull/3030/head
TianYuan 3 years ago
parent 96038fb01b
commit 0a306db5ff

@ -65,14 +65,14 @@ PredictorInterface::~PredictorInterface() {}
template <typename WavDataType> template <typename WavDataType>
class Predictor : public PredictorInterface { class Predictor : public PredictorInterface {
public: public:
virtual bool Init(const std::string &AcousticModelPath, bool Init(const std::string &AcousticModelPath,
const std::string &VocoderPath, const std::string &VocoderPath,
PowerMode cpuPowerMode, PowerMode cpuPowerMode,
int cpuThreadNum, int cpuThreadNum,
// WAV采样率必须与模型输出匹配 // WAV采样率必须与模型输出匹配
// 如果播放速度和音调异常,请修改采样率 // 如果播放速度和音调异常,请修改采样率
// 常见采样率16000, 24000, 32000, 44100, 48000, 96000 // 常见采样率16000, 24000, 32000, 44100, 48000, 96000
uint32_t wavSampleRate) override { uint32_t wavSampleRate) override {
// Release model if exists // Release model if exists
ReleaseModel(); ReleaseModel();
@ -96,7 +96,7 @@ class Predictor : public PredictorInterface {
ReleaseWav(); ReleaseWav();
} }
virtual std::shared_ptr<PaddlePredictor> LoadModel( std::shared_ptr<PaddlePredictor> LoadModel(
const std::string &modelPath, const std::string &modelPath,
int cpuThreadNum, int cpuThreadNum,
PowerMode cpuPowerMode) override { PowerMode cpuPowerMode) override {
@ -113,12 +113,12 @@ class Predictor : public PredictorInterface {
return CreatePaddlePredictor<MobileConfig>(config); return CreatePaddlePredictor<MobileConfig>(config);
} }
virtual void ReleaseModel() override { void ReleaseModel() override {
acoustic_model_predictor_ = nullptr; acoustic_model_predictor_ = nullptr;
vocoder_predictor_ = nullptr; vocoder_predictor_ = nullptr;
} }
virtual bool RunModel(const std::vector<int64_t> &phones) override { bool RunModel(const std::vector<int64_t> &phones) override {
if (!IsLoaded()) { if (!IsLoaded()) {
return false; return false;
} }
@ -139,7 +139,7 @@ class Predictor : public PredictorInterface {
return true; return true;
} }
virtual std::unique_ptr<const Tensor> GetAcousticModelOutput( std::unique_ptr<const Tensor> GetAcousticModelOutput(
const std::vector<int64_t> &phones) override { const std::vector<int64_t> &phones) override {
auto phones_handle = acoustic_model_predictor_->GetInput(0); auto phones_handle = acoustic_model_predictor_->GetInput(0);
phones_handle->Resize({static_cast<int64_t>(phones.size())}); phones_handle->Resize({static_cast<int64_t>(phones.size())});
@ -159,7 +159,7 @@ class Predictor : public PredictorInterface {
return am_output_handle; return am_output_handle;
} }
virtual std::unique_ptr<const Tensor> GetVocoderOutput( std::unique_ptr<const Tensor> GetVocoderOutput(
std::unique_ptr<const Tensor> &&amOutput) override { std::unique_ptr<const Tensor> &&amOutput) override {
auto mel_handle = vocoder_predictor_->GetInput(0); auto mel_handle = vocoder_predictor_->GetInput(0);
// [?, 80] // [?, 80]
@ -182,7 +182,7 @@ class Predictor : public PredictorInterface {
return voc_output_handle; return voc_output_handle;
} }
virtual void VocoderOutputToWav( void VocoderOutputToWav(
std::unique_ptr<const Tensor> &&vocOutput) override { std::unique_ptr<const Tensor> &&vocOutput) override {
// 获取输出Tensor的数据 // 获取输出Tensor的数据
int64_t output_size = 1; int64_t output_size = 1;
@ -194,35 +194,31 @@ class Predictor : public PredictorInterface {
SaveFloatWav(output_data, output_size); SaveFloatWav(output_data, output_size);
} }
virtual void SaveFloatWav(float *floatWav, int64_t size) override; void SaveFloatWav(float *floatWav, int64_t size) override;
virtual bool IsLoaded() override { bool IsLoaded() override {
return acoustic_model_predictor_ != nullptr && return acoustic_model_predictor_ != nullptr &&
vocoder_predictor_ != nullptr; vocoder_predictor_ != nullptr;
} }
virtual float GetInferenceTime() override { return inference_time_; } float GetInferenceTime() override { return inference_time_; }
const std::vector<WavDataType> &GetWav() { return wav_; } const std::vector<WavDataType> &GetWav() { return wav_; }
virtual int GetWavSize() override { int GetWavSize() override { return wav_.size() * sizeof(WavDataType); }
return wav_.size() * sizeof(WavDataType);
}
// 获取WAV持续时间单位毫秒 // 获取WAV持续时间单位毫秒
virtual float GetWavDuration() override { float GetWavDuration() override {
return static_cast<float>(GetWavSize()) / sizeof(WavDataType) / return static_cast<float>(GetWavSize()) / sizeof(WavDataType) /
static_cast<float>(wav_sample_rate_) * 1000; static_cast<float>(wav_sample_rate_) * 1000;
} }
// 获取RTF合成时间 / 音频时长) // 获取RTF合成时间 / 音频时长)
virtual float GetRTF() override { float GetRTF() override { return GetInferenceTime() / GetWavDuration(); }
return GetInferenceTime() / GetWavDuration();
}
virtual void ReleaseWav() override { wav_.clear(); } void ReleaseWav() override { wav_.clear(); }
virtual bool WriteWavToFile(const std::string &wavPath) override { bool WriteWavToFile(const std::string &wavPath) override {
std::ofstream fout(wavPath, std::ios::binary); std::ofstream fout(wavPath, std::ios::binary);
if (!fout.is_open()) { if (!fout.is_open()) {
return false; return false;

Loading…
Cancel
Save