diff --git a/README.md b/README.md index 3c60db650..f71d0562a 100644 --- a/README.md +++ b/README.md @@ -179,7 +179,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision ### Recent Update - 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3). -- 🎉 2023.03.07: Add [TTS ARM Linux C++ Demo](./demos/TTSArmLinux). +- 🎉 2023.03.07: Add [TTS ARM Linux C++ Demo (with C++ Chinese Text Frontend)](./demos/TTSArmLinux). - 🔥 2023.03.03 Add Voice Conversion [StarGANv2-VC synthesize pipeline](./examples/vctk/vc3). - 🎉 2023.02.16: Add [Cantonese TTS](./examples/canton/tts3). - 🔥 2023.01.10: Add [code-switch asr CLI and Demos](./demos/speech_recognition). diff --git a/README_cn.md b/README_cn.md index 29ee387c0..5771d766b 100644 --- a/README_cn.md +++ b/README_cn.md @@ -184,7 +184,7 @@ ### 近期更新 - 👑 2023.03.09: 新增 [Wav2vec2ASR-zh](./examples/aishell/asr3)。 -- 🎉 2023.03.07: 新增 [TTS ARM Linux C++ 部署示例](./demos/TTSArmLinux)。 +- 🎉 2023.03.07: 新增 [TTS ARM Linux C++ 部署示例 (包含 C++ 中文文本前端模块)](./demos/TTSArmLinux)。 - 🔥 2023.03.03: 新增声音转换模型 [StarGANv2-VC 合成流程](./examples/vctk/vc3)。 - 🎉 2023.02.16: 新增[粤语语音合成](./examples/canton/tts3)。 - 🔥 2023.01.10: 新增[中英混合 ASR CLI 和 Demos](./demos/speech_recognition)。 diff --git a/demos/TTSArmLinux/src/Predictor.hpp b/demos/TTSArmLinux/src/Predictor.hpp index 985d01158..f173abb5c 100644 --- a/demos/TTSArmLinux/src/Predictor.hpp +++ b/demos/TTSArmLinux/src/Predictor.hpp @@ -1,7 +1,20 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include -#include #include +#include #include #include #include @@ -10,24 +23,28 @@ using namespace paddle::lite_api; class PredictorInterface { -public: + public: virtual ~PredictorInterface() = 0; - virtual bool Init( - const std::string &AcousticModelPath, - const std::string &VocoderPath, - PowerMode cpuPowerMode, - int cpuThreadNum, - // WAV采样率(必须与模型输出匹配) - // 如果播放速度和音调异常,请修改采样率 - // 常见采样率:16000, 24000, 32000, 44100, 48000, 96000 - uint32_t wavSampleRate - ) = 0; - virtual std::shared_ptr LoadModel(const std::string &modelPath, int cpuThreadNum, PowerMode cpuPowerMode) = 0; + virtual bool Init(const std::string &AcousticModelPath, + const std::string &VocoderPath, + PowerMode cpuPowerMode, + int cpuThreadNum, + // WAV采样率(必须与模型输出匹配) + // 如果播放速度和音调异常,请修改采样率 + // 常见采样率:16000, 24000, 32000, 44100, 48000, 96000 + uint32_t wavSampleRate) = 0; + virtual std::shared_ptr LoadModel( + const std::string &modelPath, + int cpuThreadNum, + PowerMode cpuPowerMode) = 0; virtual void ReleaseModel() = 0; virtual bool RunModel(const std::vector &phones) = 0; - virtual std::unique_ptr GetAcousticModelOutput(const std::vector &phones) = 0; - virtual std::unique_ptr GetVocoderOutput(std::unique_ptr &&amOutput) = 0; - virtual void VocoderOutputToWav(std::unique_ptr &&vocOutput) = 0; + virtual std::unique_ptr GetAcousticModelOutput( + const std::vector &phones) = 0; + virtual std::unique_ptr GetVocoderOutput( + std::unique_ptr &&amOutput) = 0; + virtual void VocoderOutputToWav( + std::unique_ptr &&vocOutput) = 0; virtual void SaveFloatWav(float *floatWav, int64_t size) = 0; virtual bool IsLoaded() = 0; virtual float GetInferenceTime() = 0; @@ -45,23 +62,22 @@ PredictorInterface::~PredictorInterface() {} // WavDataType: WAV数据类型 // 可在 int16_t 和 float 之间切换, // 用于生成 16-bit PCM 或 32-bit IEEE float 格式的 WAV -template +template class Predictor : public PredictorInterface { -public: - virtual bool Init( - const std::string &AcousticModelPath, - const std::string &VocoderPath, - PowerMode cpuPowerMode, - int cpuThreadNum, - // WAV采样率(必须与模型输出匹配) - // 如果播放速度和音调异常,请修改采样率 - // 常见采样率:16000, 24000, 32000, 44100, 48000, 96000 - uint32_t wavSampleRate - ) override { + public: + bool Init(const std::string &AcousticModelPath, + const std::string &VocoderPath, + PowerMode cpuPowerMode, + int cpuThreadNum, + // WAV采样率(必须与模型输出匹配) + // 如果播放速度和音调异常,请修改采样率 + // 常见采样率:16000, 24000, 32000, 44100, 48000, 96000 + uint32_t wavSampleRate) override { // Release model if exists ReleaseModel(); - acoustic_model_predictor_ = LoadModel(AcousticModelPath, cpuThreadNum, cpuPowerMode); + acoustic_model_predictor_ = + LoadModel(AcousticModelPath, cpuThreadNum, cpuPowerMode); if (acoustic_model_predictor_ == nullptr) { return false; } @@ -80,7 +96,10 @@ public: ReleaseWav(); } - virtual std::shared_ptr LoadModel(const std::string &modelPath, int cpuThreadNum, PowerMode cpuPowerMode) override { + std::shared_ptr LoadModel( + const std::string &modelPath, + int cpuThreadNum, + PowerMode cpuPowerMode) override { if (modelPath.empty()) { return nullptr; } @@ -94,12 +113,12 @@ public: return CreatePaddlePredictor(config); } - virtual void ReleaseModel() override { + void ReleaseModel() override { acoustic_model_predictor_ = nullptr; vocoder_predictor_ = nullptr; } - virtual bool RunModel(const std::vector &phones) override { + bool RunModel(const std::vector &phones) override { if (!IsLoaded()) { return false; } @@ -115,12 +134,13 @@ public: // 计算用时 std::chrono::duration duration = end - start; - inference_time_ = duration.count() * 1000; // 单位:毫秒 + inference_time_ = duration.count() * 1000; // 单位:毫秒 return true; } - virtual std::unique_ptr GetAcousticModelOutput(const std::vector &phones) override { + std::unique_ptr GetAcousticModelOutput( + const std::vector &phones) override { auto phones_handle = acoustic_model_predictor_->GetInput(0); phones_handle->Resize({static_cast(phones.size())}); phones_handle->CopyFromCpu(phones.data()); @@ -139,7 +159,8 @@ public: return am_output_handle; } - virtual std::unique_ptr GetVocoderOutput(std::unique_ptr &&amOutput) override { + std::unique_ptr GetVocoderOutput( + std::unique_ptr &&amOutput) override { auto mel_handle = vocoder_predictor_->GetInput(0); // [?, 80] auto dims = amOutput->shape(); @@ -161,7 +182,8 @@ public: return voc_output_handle; } - virtual void VocoderOutputToWav(std::unique_ptr &&vocOutput) override { + void VocoderOutputToWav( + std::unique_ptr &&vocOutput) override { // 获取输出Tensor的数据 int64_t output_size = 1; for (auto dim : vocOutput->shape()) { @@ -172,39 +194,31 @@ public: SaveFloatWav(output_data, output_size); } - virtual void SaveFloatWav(float *floatWav, int64_t size) override; + void SaveFloatWav(float *floatWav, int64_t size) override; - virtual bool IsLoaded() override { - return acoustic_model_predictor_ != nullptr && vocoder_predictor_ != nullptr; + bool IsLoaded() override { + return acoustic_model_predictor_ != nullptr && + vocoder_predictor_ != nullptr; } - virtual float GetInferenceTime() override { - return inference_time_; - } + float GetInferenceTime() override { return inference_time_; } - const std::vector & GetWav() { - return wav_; - } + const std::vector &GetWav() { return wav_; } - virtual int GetWavSize() override { - return wav_.size() * sizeof(WavDataType); - } + int GetWavSize() override { return wav_.size() * sizeof(WavDataType); } // 获取WAV持续时间(单位:毫秒) - virtual float GetWavDuration() override { - return static_cast(GetWavSize()) / sizeof(WavDataType) / static_cast(wav_sample_rate_) * 1000; + float GetWavDuration() override { + return static_cast(GetWavSize()) / sizeof(WavDataType) / + static_cast(wav_sample_rate_) * 1000; } // 获取RTF(合成时间 / 音频时长) - virtual float GetRTF() override { - return GetInferenceTime() / GetWavDuration(); - } + float GetRTF() override { return GetInferenceTime() / GetWavDuration(); } - virtual void ReleaseWav() override { - wav_.clear(); - } + void ReleaseWav() override { wav_.clear(); } - virtual bool WriteWavToFile(const std::string &wavPath) override { + bool WriteWavToFile(const std::string &wavPath) override { std::ofstream fout(wavPath, std::ios::binary); if (!fout.is_open()) { return false; @@ -216,18 +230,20 @@ public: header.data_size = GetWavSize(); header.size = sizeof(header) - 8 + header.data_size; header.sample_rate = wav_sample_rate_; - header.byte_rate = header.sample_rate * header.num_channels * header.bits_per_sample / 8; + header.byte_rate = header.sample_rate * header.num_channels * + header.bits_per_sample / 8; header.block_align = header.num_channels * header.bits_per_sample / 8; - fout.write(reinterpret_cast(&header), sizeof(header)); + fout.write(reinterpret_cast(&header), sizeof(header)); // 写入wav数据 - fout.write(reinterpret_cast(wav_.data()), header.data_size); + fout.write(reinterpret_cast(wav_.data()), + header.data_size); fout.close(); return true; } -protected: + protected: struct WavHeader { // RIFF 头 char riff[4] = {'R', 'I', 'F', 'F'}; @@ -250,19 +266,17 @@ protected: }; enum WavAudioFormat { - WAV_FORMAT_16BIT_PCM = 1, // 16-bit PCM 格式 + WAV_FORMAT_16BIT_PCM = 1, // 16-bit PCM 格式 WAV_FORMAT_32BIT_FLOAT = 3 // 32-bit IEEE float 格式 }; -protected: + protected: // 返回值通过模板特化由 WavDataType 决定 inline uint16_t GetWavAudioFormat(); - inline float Abs(float number) { - return (number < 0) ? -number : number; - } + inline float Abs(float number) { return (number < 0) ? -number : number; } -protected: + protected: float inference_time_ = 0; uint32_t wav_sample_rate_ = 0; std::vector wav_; @@ -270,36 +284,36 @@ protected: std::shared_ptr vocoder_predictor_ = nullptr; }; -template<> +template <> uint16_t Predictor::GetWavAudioFormat() { return Predictor::WAV_FORMAT_16BIT_PCM; } -template<> +template <> uint16_t Predictor::GetWavAudioFormat() { return Predictor::WAV_FORMAT_32BIT_FLOAT; } // 保存 16-bit PCM 格式 WAV -template<> +template <> void Predictor::SaveFloatWav(float *floatWav, int64_t size) { wav_.resize(size); float maxSample = 0.01; // 寻找最大采样值 - for (int64_t i=0; i maxSample) { maxSample = sample; } } // 把采样值缩放到 int_16 范围 - for (int64_t i=0; i +template <> void Predictor::SaveFloatWav(float *floatWav, int64_t size) { wav_.resize(size); std::copy_n(floatWav, size, wav_.data()); diff --git a/demos/TTSArmLinux/src/main.cc b/demos/TTSArmLinux/src/main.cc index f3bd0f7b0..0b8e26bc4 100644 --- a/demos/TTSArmLinux/src/main.cc +++ b/demos/TTSArmLinux/src/main.cc @@ -1,23 +1,48 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include #include #include +#include #include #include -#include -#include -#include -#include -#include #include "Predictor.hpp" using namespace paddle::lite_api; -DEFINE_string(sentence, "你好,欢迎使用语音合成服务", "Text to be synthesized (Chinese only. English will crash the program.)"); +DEFINE_string( + sentence, + "你好,欢迎使用语音合成服务", + "Text to be synthesized (Chinese only. English will crash the program.)"); DEFINE_string(front_conf, "./front.conf", "Front configuration file"); -DEFINE_string(acoustic_model, "./models/cpu/fastspeech2_csmsc_arm.nb", "Acoustic model .nb file"); -DEFINE_string(vocoder, "./models/cpu/fastspeech2_csmsc_arm.nb", "vocoder .nb file"); +DEFINE_string(acoustic_model, + "./models/cpu/fastspeech2_csmsc_arm.nb", + "Acoustic model .nb file"); +DEFINE_string(vocoder, + "./models/cpu/fastspeech2_csmsc_arm.nb", + "vocoder .nb file"); DEFINE_string(output_wav, "./output/tts.wav", "Output WAV file"); -DEFINE_string(wav_bit_depth, "16", "WAV bit depth, 16 (16-bit PCM) or 32 (32-bit IEEE float)"); -DEFINE_string(wav_sample_rate, "24000", "WAV sample rate, should match the output of the vocoder"); +DEFINE_string(wav_bit_depth, + "16", + "WAV bit depth, 16 (16-bit PCM) or 32 (32-bit IEEE float)"); +DEFINE_string(wav_sample_rate, + "24000", + "WAV sample rate, should match the output of the vocoder"); DEFINE_string(cpu_thread, "1", "CPU thread numbers"); int main(int argc, char *argv[]) { @@ -53,7 +78,7 @@ int main(int argc, char *argv[]) { // 繁体转简体 std::wstring sentence_simp; - front_inst->Trand2Simp(ws_sentence, sentence_simp); + front_inst->Trand2Simp(ws_sentence, &sentence_simp); ws_sentence = sentence_simp; std::string s_sentence; @@ -63,28 +88,30 @@ int main(int argc, char *argv[]) { // 根据标点进行分句 LOG(INFO) << "Start to segment sentences by punctuation"; - front_inst->SplitByPunc(ws_sentence, sentence_part); + front_inst->SplitByPunc(ws_sentence, &sentence_part); LOG(INFO) << "Segment sentences through punctuation successfully"; // 分句后获取音素id - LOG(INFO) << "Start to get the phoneme and tone id sequence of each sentence"; - for(int i = 0; i < sentence_part.size(); i++) { - - LOG(INFO) << "Raw sentence is: " << ppspeech::wstring2utf8string(sentence_part[i]); - front_inst->SentenceNormalize(sentence_part[i]); + LOG(INFO) + << "Start to get the phoneme and tone id sequence of each sentence"; + for (int i = 0; i < sentence_part.size(); i++) { + LOG(INFO) << "Raw sentence is: " + << ppspeech::wstring2utf8string(sentence_part[i]); + front_inst->SentenceNormalize(&sentence_part[i]); s_sentence = ppspeech::wstring2utf8string(sentence_part[i]); LOG(INFO) << "After normalization sentence is: " << s_sentence; - - if (0 != front_inst->GetSentenceIds(s_sentence, phoneids, toneids)) { + + if (0 != front_inst->GetSentenceIds(s_sentence, &phoneids, &toneids)) { LOG(ERROR) << "TTS inst get sentence phoneids and toneids failed"; return -1; } - } - LOG(INFO) << "The phoneids of the sentence is: " << limonp::Join(phoneids.begin(), phoneids.end(), " "); - LOG(INFO) << "The toneids of the sentence is: " << limonp::Join(toneids.begin(), toneids.end(), " "); + LOG(INFO) << "The phoneids of the sentence is: " + << limonp::Join(phoneids.begin(), phoneids.end(), " "); + LOG(INFO) << "The toneids of the sentence is: " + << limonp::Join(toneids.begin(), toneids.end(), " "); LOG(INFO) << "Get the phoneme id sequence of each sentence successfully"; - + /////////////////////////// 后端:音素转音频 /////////////////////////// @@ -99,13 +126,19 @@ int main(int argc, char *argv[]) { // CPU电源模式 const PowerMode cpuPowerMode = PowerMode::LITE_POWER_HIGH; - if (!predictor->Init(FLAGS_acoustic_model, FLAGS_vocoder, cpuPowerMode, cpuThreadNum, wavSampleRate)) { + if (!predictor->Init(FLAGS_acoustic_model, + FLAGS_vocoder, + cpuPowerMode, + cpuThreadNum, + wavSampleRate)) { LOG(ERROR) << "predictor init failed" << std::endl; return -1; } std::vector phones(phoneids.size()); - std::transform(phoneids.begin(), phoneids.end(), phones.begin(), [](int x) { return static_cast(x); }); + std::transform(phoneids.begin(), phoneids.end(), phones.begin(), [](int x) { + return static_cast(x); + }); if (!predictor->RunModel(phones)) { LOG(ERROR) << "predictor run model failed" << std::endl; @@ -113,7 +146,8 @@ int main(int argc, char *argv[]) { } LOG(INFO) << "Inference time: " << predictor->GetInferenceTime() << " ms, " - << "WAV size (without header): " << predictor->GetWavSize() << " bytes, " + << "WAV size (without header): " << predictor->GetWavSize() + << " bytes, " << "WAV duration: " << predictor->GetWavDuration() << " ms, " << "RTF: " << predictor->GetRTF() << std::endl; diff --git a/demos/TTSCppFrontend/README.md b/demos/TTSCppFrontend/README.md index 592140ae1..552858de3 100644 --- a/demos/TTSCppFrontend/README.md +++ b/demos/TTSCppFrontend/README.md @@ -38,6 +38,7 @@ If the download speed is too slow, you can open [third-party/CMakeLists.txt](thi ``` ## Run +You can change `--phone2id_path` in `./front_demo/front.conf` to the `phone_id_map.txt` of your own acoustic model. ``` ./run_front_demo.sh diff --git a/demos/TTSCppFrontend/front_demo/front_demo.cpp b/demos/TTSCppFrontend/front_demo/front_demo.cpp index e943fd6f7..19f16758b 100644 --- a/demos/TTSCppFrontend/front_demo/front_demo.cpp +++ b/demos/TTSCppFrontend/front_demo/front_demo.cpp @@ -1,19 +1,32 @@ -#include -//#include "utils/dir_utils.h" -#include "front/front_interface.h" -#include +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include +#include #include +#include +#include "front/front_interface.h" DEFINE_string(sentence, "你好,欢迎使用语音合成服务", "Text to be synthesized"); DEFINE_string(front_conf, "./front_demo/front.conf", "Front conf file"); -//DEFINE_string(seperate_tone, "true", "If true, get phoneids and tonesid"); +// DEFINE_string(seperate_tone, "true", "If true, get phoneids and tonesid"); int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); // 实例化文本前端引擎 - ppspeech::FrontEngineInterface *front_inst = nullptr; + ppspeech::FrontEngineInterface* front_inst = nullptr; front_inst = new ppspeech::FrontEngineInterface(FLAGS_front_conf); if ((!front_inst) || (front_inst->init())) { LOG(ERROR) << "Creater tts engine failed!"; @@ -28,7 +41,7 @@ int main(int argc, char** argv) { // 繁体转简体 std::wstring sentence_simp; - front_inst->Trand2Simp(ws_sentence, sentence_simp); + front_inst->Trand2Simp(ws_sentence, &sentence_simp); ws_sentence = sentence_simp; std::string s_sentence; @@ -38,28 +51,29 @@ int main(int argc, char** argv) { // 根据标点进行分句 LOG(INFO) << "Start to segment sentences by punctuation"; - front_inst->SplitByPunc(ws_sentence, sentence_part); + front_inst->SplitByPunc(ws_sentence, &sentence_part); LOG(INFO) << "Segment sentences through punctuation successfully"; // 分句后获取音素id - LOG(INFO) << "Start to get the phoneme and tone id sequence of each sentence"; - for(int i = 0; i < sentence_part.size(); i++) { - - LOG(INFO) << "Raw sentence is: " << ppspeech::wstring2utf8string(sentence_part[i]); - front_inst->SentenceNormalize(sentence_part[i]); + LOG(INFO) + << "Start to get the phoneme and tone id sequence of each sentence"; + for (int i = 0; i < sentence_part.size(); i++) { + LOG(INFO) << "Raw sentence is: " + << ppspeech::wstring2utf8string(sentence_part[i]); + front_inst->SentenceNormalize(&sentence_part[i]); s_sentence = ppspeech::wstring2utf8string(sentence_part[i]); LOG(INFO) << "After normalization sentence is: " << s_sentence; - - if (0 != front_inst->GetSentenceIds(s_sentence, phoneids, toneids)) { + + if (0 != front_inst->GetSentenceIds(s_sentence, &phoneids, &toneids)) { LOG(ERROR) << "TTS inst get sentence phoneids and toneids failed"; return -1; } - } - LOG(INFO) << "The phoneids of the sentence is: " << limonp::Join(phoneids.begin(), phoneids.end(), " "); - LOG(INFO) << "The toneids of the sentence is: " << limonp::Join(toneids.begin(), toneids.end(), " "); + LOG(INFO) << "The phoneids of the sentence is: " + << limonp::Join(phoneids.begin(), phoneids.end(), " "); + LOG(INFO) << "The toneids of the sentence is: " + << limonp::Join(toneids.begin(), toneids.end(), " "); LOG(INFO) << "Get the phoneme id sequence of each sentence successfully"; - + return EXIT_SUCCESS; } - diff --git a/demos/TTSCppFrontend/front_demo/gentools/gen_dict_paddlespeech.py b/demos/TTSCppFrontend/front_demo/gentools/gen_dict_paddlespeech.py index e9a2c96f6..5aaa6e345 100644 --- a/demos/TTSCppFrontend/front_demo/gentools/gen_dict_paddlespeech.py +++ b/demos/TTSCppFrontend/front_demo/gentools/gen_dict_paddlespeech.py @@ -1,19 +1,28 @@ -# !/usr/bin/env python3 -# -*- coding: utf-8 -*- -######################################################################## +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # -# Copyright 2021 liangyunming(liangyunming@baidu.com) +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# Execute the script when PaddleSpeech has been installed -# PaddleSpeech: https://github.com/PaddlePaddle/PaddleSpeech - -######################################################################## - +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import argparse import configparser + from paddlespeech.t2s.frontend.zh_frontend import Frontend -def get_phone(frontend, word, merge_sentences=True, print_info=False, robot=False, get_tone_ids=False): + +def get_phone(frontend, + word, + merge_sentences=True, + print_info=False, + robot=False, + get_tone_ids=False): phonemes = frontend.get_phonemes(word, merge_sentences, print_info, robot) # Some optimizations phones, tones = frontend._get_phone_tone(phonemes[0], get_tone_ids) @@ -22,7 +31,10 @@ def get_phone(frontend, word, merge_sentences=True, print_info=False, robot=Fals return phones, tones -def gen_word2phone_dict(frontend, jieba_words_dict, word2phone_dict, get_tone=False): +def gen_word2phone_dict(frontend, + jieba_words_dict, + word2phone_dict, + get_tone=False): with open(jieba_words_dict, "r") as f1, open(word2phone_dict, "w+") as f2: for line in f1.readlines(): word = line.split(" ")[0] @@ -30,9 +42,9 @@ def gen_word2phone_dict(frontend, jieba_words_dict, word2phone_dict, get_tone=Fa phone_str = "" if tone: - assert(len(phone) == len(tone)) + assert (len(phone) == len(tone)) for i in range(len(tone)): - phone_tone = phone[i] + tone[i] + phone_tone = phone[i] + tone[i] phone_str += (" " + phone_tone) phone_str = phone_str.strip("sp0").strip(" ") else: @@ -45,43 +57,55 @@ def gen_word2phone_dict(frontend, jieba_words_dict, word2phone_dict, get_tone=Fa def main(): - parser = argparse.ArgumentParser( - description="Generate dictionary") + parser = argparse.ArgumentParser(description="Generate dictionary") parser.add_argument( "--config", type=str, default="./config.ini", help="config file.") parser.add_argument( - "--am_type", type=str, default="fastspeech2", help="fastspeech2 or speedyspeech") + "--am_type", + type=str, + default="fastspeech2", + help="fastspeech2 or speedyspeech") args = parser.parse_args() # Read config cf = configparser.ConfigParser() cf.read(args.config) - jieba_words_dict_file = cf.get("jieba", "jieba_words_dict") # get words dict + jieba_words_dict_file = cf.get("jieba", + "jieba_words_dict") # get words dict am_type = args.am_type - if(am_type == "fastspeech2"): + if (am_type == "fastspeech2"): phone2id_dict_file = cf.get(am_type, "phone2id_dict") word2phone_dict_file = cf.get(am_type, "word2phone_dict") frontend = Frontend(phone_vocab_path=phone2id_dict_file) print("frontend done!") - gen_word2phone_dict(frontend, jieba_words_dict_file, word2phone_dict_file, get_tone=False) - - elif(am_type == "speedyspeech"): + gen_word2phone_dict( + frontend, + jieba_words_dict_file, + word2phone_dict_file, + get_tone=False) + + elif (am_type == "speedyspeech"): phone2id_dict_file = cf.get(am_type, "phone2id_dict") tone2id_dict_file = cf.get(am_type, "tone2id_dict") word2phone_dict_file = cf.get(am_type, "word2phone_dict") - frontend = Frontend(phone_vocab_path=phone2id_dict_file, tone_vocab_path=tone2id_dict_file) + frontend = Frontend( + phone_vocab_path=phone2id_dict_file, + tone_vocab_path=tone2id_dict_file) print("frontend done!") - gen_word2phone_dict(frontend, jieba_words_dict_file, word2phone_dict_file, get_tone=True) - + gen_word2phone_dict( + frontend, + jieba_words_dict_file, + word2phone_dict_file, + get_tone=True) else: print("Please set correct am type, fastspeech2 or speedyspeech.") - - + + if __name__ == "__main__": main() diff --git a/demos/TTSCppFrontend/front_demo/gentools/genid.py b/demos/TTSCppFrontend/front_demo/gentools/genid.py index e2866bb0e..cf83623f0 100644 --- a/demos/TTSCppFrontend/front_demo/gentools/genid.py +++ b/demos/TTSCppFrontend/front_demo/gentools/genid.py @@ -1,10 +1,23 @@ -#from parakeet.frontend.vocab import Vocab +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. PHONESFILE = "./dict/phones.txt" PHONES_ID_FILE = "./dict/phonesid.dict" TONESFILE = "./dict/tones.txt" TONES_ID_FILE = "./dict/tonesid.dict" + def GenIdFile(file, idfile): id = 2 with open(file, 'r') as f1, open(idfile, "w+") as f2: @@ -16,7 +29,7 @@ def GenIdFile(file, idfile): f2.write(phone + " " + str(id) + "\n") id += 1 + if __name__ == "__main__": GenIdFile(PHONESFILE, PHONES_ID_FILE) GenIdFile(TONESFILE, TONES_ID_FILE) - diff --git a/demos/TTSCppFrontend/front_demo/gentools/word2phones.py b/demos/TTSCppFrontend/front_demo/gentools/word2phones.py index 6a1822023..8726ee89c 100644 --- a/demos/TTSCppFrontend/front_demo/gentools/word2phones.py +++ b/demos/TTSCppFrontend/front_demo/gentools/word2phones.py @@ -1,9 +1,25 @@ -from pypinyin import lazy_pinyin, Style +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import re +from pypinyin import lazy_pinyin +from pypinyin import Style + worddict = "./dict/jieba_part.dict.utf8" newdict = "./dict/word_phones.dict" + def GenPhones(initials, finals, seperate=True): phones = [] @@ -14,9 +30,9 @@ def GenPhones(initials, finals, seperate=True): elif c in ['zh', 'ch', 'sh', 'r']: v = re.sub('i', 'iii', v) if c: - if seperate == True: + if seperate is True: phones.append(c + '0') - elif seperate == False: + elif seperate is False: phones.append(c) else: print("Not sure whether phone and tone need to be separated") @@ -28,8 +44,10 @@ def GenPhones(initials, finals, seperate=True): with open(worddict, "r") as f1, open(newdict, "w+") as f2: for line in f1.readlines(): word = line.split(" ")[0] - initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS) - finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) + initials = lazy_pinyin( + word, neutral_tone_with_five=True, style=Style.INITIALS) + finals = lazy_pinyin( + word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) phones = GenPhones(initials, finals, True) diff --git a/demos/TTSCppFrontend/src/base/type_conv.cpp b/demos/TTSCppFrontend/src/base/type_conv.cpp index 5d5de43c5..b7ff63642 100644 --- a/demos/TTSCppFrontend/src/base/type_conv.cpp +++ b/demos/TTSCppFrontend/src/base/type_conv.cpp @@ -1,18 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "base/type_conv.h" namespace ppspeech { // wstring to string -std::string wstring2utf8string(const std::wstring& str) -{ - static std::wstring_convert > strCnv; +std::string wstring2utf8string(const std::wstring& str) { + static std::wstring_convert> strCnv; return strCnv.to_bytes(str); } - -// string to wstring -std::wstring utf8string2wstring(const std::string& str) -{ - static std::wstring_convert< std::codecvt_utf8 > strCnv; - return strCnv.from_bytes(str); -} +// string to wstring +std::wstring utf8string2wstring(const std::string& str) { + static std::wstring_convert> strCnv; + return strCnv.from_bytes(str); } +} // namespace ppspeech diff --git a/demos/TTSCppFrontend/src/base/type_conv.h b/demos/TTSCppFrontend/src/base/type_conv.h index 9acb7a6d2..6aecfc438 100644 --- a/demos/TTSCppFrontend/src/base/type_conv.h +++ b/demos/TTSCppFrontend/src/base/type_conv.h @@ -1,18 +1,31 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #ifndef BASE_TYPE_CONVC_H #define BASE_TYPE_CONVC_H -#include -#include #include +#include +#include namespace ppspeech { // wstring to string std::string wstring2utf8string(const std::wstring& str); - -// string to wstring -std::wstring utf8string2wstring(const std::string& str); +// string to wstring +std::wstring utf8string2wstring(const std::string& str); } #endif // BASE_TYPE_CONVC_H \ No newline at end of file diff --git a/demos/TTSCppFrontend/src/front/front_interface.cpp b/demos/TTSCppFrontend/src/front/front_interface.cpp index 5b828ac1b..8bd466d28 100644 --- a/demos/TTSCppFrontend/src/front/front_interface.cpp +++ b/demos/TTSCppFrontend/src/front/front_interface.cpp @@ -1,3 +1,16 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "front/front_interface.h" namespace ppspeech { @@ -5,96 +18,123 @@ namespace ppspeech { int FrontEngineInterface::init() { if (_initialed) { return 0; - } + } if (0 != ReadConfFile()) { LOG(ERROR) << "Read front conf file failed"; return -1; } - _jieba = new cppjieba::Jieba(_jieba_dict_path, _jieba_hmm_path, _jieba_user_dict_path, - _jieba_idf_path, _jieba_stop_word_path); - - _punc = {",", "。", "、", "?", ":", ";", "~", "!", - ",", ".", "?", "!", ":", ";", "/", "\\"}; - _punc_omit = {"“", "”", "\"", "\""}; + _jieba = new cppjieba::Jieba(_jieba_dict_path, + _jieba_hmm_path, + _jieba_user_dict_path, + _jieba_idf_path, + _jieba_stop_word_path); + + _punc = {",", + "。", + "、", + "?", + ":", + ";", + "~", + "!", + ",", + ".", + "?", + "!", + ":", + ";", + "/", + "\\"}; + _punc_omit = {"“", "”", "\"", "\""}; // 需要儿化音处理的词语 - must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿"}; - not_erhua = { - "虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿", - "拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿", - "流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿", - "孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿", - "狗儿" - }; - - must_not_neural_tone_words = {"男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子"}; + must_erhua = { + "小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿"}; + not_erhua = {"虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", + "有儿", "一儿", "我儿", "俺儿", "妻儿", "拐儿", + "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", + "婴幼儿", "连体儿", "脑瘫儿", "流浪儿", "体弱儿", "混血儿", + "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", + "侄儿", "孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", + "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿", + "狗儿"}; + + must_not_neural_tone_words = { + "男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子"}; // 需要轻声处理的词语 must_neural_tone_words = { - "麻烦", "麻利", "鸳鸯", "高粱", "骨头", "骆驼", "马虎", "首饰", "馒头", "馄饨", "风筝", - "难为", "队伍", "阔气", "闺女", "门道", "锄头", "铺盖", "铃铛", "铁匠", "钥匙", "里脊", - "里头", "部分", "那么", "道士", "造化", "迷糊", "连累", "这么", "这个", "运气", "过去", - "软和", "转悠", "踏实", "跳蚤", "跟头", "趔趄", "财主", "豆腐", "讲究", "记性", "记号", - "认识", "规矩", "见识", "裁缝", "补丁", "衣裳", "衣服", "衙门", "街坊", "行李", "行当", - "蛤蟆", "蘑菇", "薄荷", "葫芦", "葡萄", "萝卜", "荸荠", "苗条", "苗头", "苍蝇", "芝麻", - "舒服", "舒坦", "舌头", "自在", "膏药", "脾气", "脑袋", "脊梁", "能耐", "胳膊", "胭脂", - "胡萝", "胡琴", "胡同", "聪明", "耽误", "耽搁", "耷拉", "耳朵", "老爷", "老实", "老婆", - "老头", "老太", "翻腾", "罗嗦", "罐头", "编辑", "结实", "红火", "累赘", "糨糊", "糊涂", - "精神", "粮食", "簸箕", "篱笆", "算计", "算盘", "答应", "笤帚", "笑语", "笑话", "窟窿", - "窝囊", "窗户", "稳当", "稀罕", "称呼", "秧歌", "秀气", "秀才", "福气", "祖宗", "砚台", - "码头", "石榴", "石头", "石匠", "知识", "眼睛", "眯缝", "眨巴", "眉毛", "相声", "盘算", - "白净", "痢疾", "痛快", "疟疾", "疙瘩", "疏忽", "畜生", "生意", "甘蔗", "琵琶", "琢磨", - "琉璃", "玻璃", "玫瑰", "玄乎", "狐狸", "状元", "特务", "牲口", "牙碜", "牌楼", "爽快", - "爱人", "热闹", "烧饼", "烟筒", "烂糊", "点心", "炊帚", "灯笼", "火候", "漂亮", "滑溜", - "溜达", "温和", "清楚", "消息", "浪头", "活泼", "比方", "正经", "欺负", "模糊", "槟榔", - "棺材", "棒槌", "棉花", "核桃", "栅栏", "柴火", "架势", "枕头", "枇杷", "机灵", "本事", - "木头", "木匠", "朋友", "月饼", "月亮", "暖和", "明白", "时候", "新鲜", "故事", "收拾", - "收成", "提防", "挖苦", "挑剔", "指甲", "指头", "拾掇", "拳头", "拨弄", "招牌", "招呼", - "抬举", "护士", "折腾", "扫帚", "打量", "打算", "打点", "打扮", "打听", "打发", "扎实", - "扁担", "戒指", "懒得", "意识", "意思", "情形", "悟性", "怪物", "思量", "怎么", "念头", - "念叨", "快活", "忙活", "志气", "心思", "得罪", "张罗", "弟兄", "开通", "应酬", "庄稼", - "干事", "帮手", "帐篷", "希罕", "师父", "师傅", "巴结", "巴掌", "差事", "工夫", "岁数", - "屁股", "尾巴", "少爷", "小气", "小伙", "将就", "对头", "对付", "寡妇", "家伙", "客气", - "实在", "官司", "学问", "学生", "字号", "嫁妆", "媳妇", "媒人", "婆家", "娘家", "委屈", - "姑娘", "姐夫", "妯娌", "妥当", "妖精", "奴才", "女婿", "头发", "太阳", "大爷", "大方", - "大意", "大夫", "多少", "多么", "外甥", "壮实", "地道", "地方", "在乎", "困难", "嘴巴", - "嘱咐", "嘟囔", "嘀咕", "喜欢", "喇嘛", "喇叭", "商量", "唾沫", "哑巴", "哈欠", "哆嗦", - "咳嗽", "和尚", "告诉", "告示", "含糊", "吓唬", "后头", "名字", "名堂", "合同", "吆喝", - "叫唤", "口袋", "厚道", "厉害", "千斤", "包袱", "包涵", "匀称", "勤快", "动静", "动弹", - "功夫", "力气", "前头", "刺猬", "刺激", "别扭", "利落", "利索", "利害", "分析", "出息", - "凑合", "凉快", "冷战", "冤枉", "冒失", "养活", "关系", "先生", "兄弟", "便宜", "使唤", - "佩服", "作坊", "体面", "位置", "似的", "伙计", "休息", "什么", "人家", "亲戚", "亲家", - "交情", "云彩", "事情", "买卖", "主意", "丫头", "丧气", "两口", "东西", "东家", "世故", - "不由", "不在", "下水", "下巴", "上头", "上司", "丈夫", "丈人", "一辈", "那个", "菩萨", - "父亲", "母亲", "咕噜", "邋遢", "费用", "冤家", "甜头", "介绍", "荒唐", "大人", "泥鳅", - "幸福", "熟悉", "计划", "扑腾", "蜡烛", "姥爷", "照顾", "喉咙", "吉他", "弄堂", "蚂蚱", - "凤凰", "拖沓", "寒碜", "糟蹋", "倒腾", "报复", "逻辑", "盘缠", "喽啰", "牢骚", "咖喱", - "扫把", "惦记" - }; - - + "麻烦", "麻利", "鸳鸯", "高粱", "骨头", "骆驼", "马虎", "首饰", "馒头", + "馄饨", "风筝", "难为", "队伍", "阔气", "闺女", "门道", "锄头", "铺盖", + "铃铛", "铁匠", "钥匙", "里脊", "里头", "部分", "那么", "道士", "造化", + "迷糊", "连累", "这么", "这个", "运气", "过去", "软和", "转悠", "踏实", + "跳蚤", "跟头", "趔趄", "财主", "豆腐", "讲究", "记性", "记号", "认识", + "规矩", "见识", "裁缝", "补丁", "衣裳", "衣服", "衙门", "街坊", "行李", + "行当", "蛤蟆", "蘑菇", "薄荷", "葫芦", "葡萄", "萝卜", "荸荠", "苗条", + "苗头", "苍蝇", "芝麻", "舒服", "舒坦", "舌头", "自在", "膏药", "脾气", + "脑袋", "脊梁", "能耐", "胳膊", "胭脂", "胡萝", "胡琴", "胡同", "聪明", + "耽误", "耽搁", "耷拉", "耳朵", "老爷", "老实", "老婆", "老头", "老太", + "翻腾", "罗嗦", "罐头", "编辑", "结实", "红火", "累赘", "糨糊", "糊涂", + "精神", "粮食", "簸箕", "篱笆", "算计", "算盘", "答应", "笤帚", "笑语", + "笑话", "窟窿", "窝囊", "窗户", "稳当", "稀罕", "称呼", "秧歌", "秀气", + "秀才", "福气", "祖宗", "砚台", "码头", "石榴", "石头", "石匠", "知识", + "眼睛", "眯缝", "眨巴", "眉毛", "相声", "盘算", "白净", "痢疾", "痛快", + "疟疾", "疙瘩", "疏忽", "畜生", "生意", "甘蔗", "琵琶", "琢磨", "琉璃", + "玻璃", "玫瑰", "玄乎", "狐狸", "状元", "特务", "牲口", "牙碜", "牌楼", + "爽快", "爱人", "热闹", "烧饼", "烟筒", "烂糊", "点心", "炊帚", "灯笼", + "火候", "漂亮", "滑溜", "溜达", "温和", "清楚", "消息", "浪头", "活泼", + "比方", "正经", "欺负", "模糊", "槟榔", "棺材", "棒槌", "棉花", "核桃", + "栅栏", "柴火", "架势", "枕头", "枇杷", "机灵", "本事", "木头", "木匠", + "朋友", "月饼", "月亮", "暖和", "明白", "时候", "新鲜", "故事", "收拾", + "收成", "提防", "挖苦", "挑剔", "指甲", "指头", "拾掇", "拳头", "拨弄", + "招牌", "招呼", "抬举", "护士", "折腾", "扫帚", "打量", "打算", "打点", + "打扮", "打听", "打发", "扎实", "扁担", "戒指", "懒得", "意识", "意思", + "情形", "悟性", "怪物", "思量", "怎么", "念头", "念叨", "快活", "忙活", + "志气", "心思", "得罪", "张罗", "弟兄", "开通", "应酬", "庄稼", "干事", + "帮手", "帐篷", "希罕", "师父", "师傅", "巴结", "巴掌", "差事", "工夫", + "岁数", "屁股", "尾巴", "少爷", "小气", "小伙", "将就", "对头", "对付", + "寡妇", "家伙", "客气", "实在", "官司", "学问", "学生", "字号", "嫁妆", + "媳妇", "媒人", "婆家", "娘家", "委屈", "姑娘", "姐夫", "妯娌", "妥当", + "妖精", "奴才", "女婿", "头发", "太阳", "大爷", "大方", "大意", "大夫", + "多少", "多么", "外甥", "壮实", "地道", "地方", "在乎", "困难", "嘴巴", + "嘱咐", "嘟囔", "嘀咕", "喜欢", "喇嘛", "喇叭", "商量", "唾沫", "哑巴", + "哈欠", "哆嗦", "咳嗽", "和尚", "告诉", "告示", "含糊", "吓唬", "后头", + "名字", "名堂", "合同", "吆喝", "叫唤", "口袋", "厚道", "厉害", "千斤", + "包袱", "包涵", "匀称", "勤快", "动静", "动弹", "功夫", "力气", "前头", + "刺猬", "刺激", "别扭", "利落", "利索", "利害", "分析", "出息", "凑合", + "凉快", "冷战", "冤枉", "冒失", "养活", "关系", "先生", "兄弟", "便宜", + "使唤", "佩服", "作坊", "体面", "位置", "似的", "伙计", "休息", "什么", + "人家", "亲戚", "亲家", "交情", "云彩", "事情", "买卖", "主意", "丫头", + "丧气", "两口", "东西", "东家", "世故", "不由", "不在", "下水", "下巴", + "上头", "上司", "丈夫", "丈人", "一辈", "那个", "菩萨", "父亲", "母亲", + "咕噜", "邋遢", "费用", "冤家", "甜头", "介绍", "荒唐", "大人", "泥鳅", + "幸福", "熟悉", "计划", "扑腾", "蜡烛", "姥爷", "照顾", "喉咙", "吉他", + "弄堂", "蚂蚱", "凤凰", "拖沓", "寒碜", "糟蹋", "倒腾", "报复", "逻辑", + "盘缠", "喽啰", "牢骚", "咖喱", "扫把", "惦记"}; + + // 生成词典(词到音素的映射) - if (0 != GenDict(_word2phone_path, word_phone_map)) { + if (0 != GenDict(_word2phone_path, &word_phone_map)) { LOG(ERROR) << "Genarate word2phone dict failed"; return -1; } // 生成音素字典(音素到音素id的映射) - if (0 != GenDict(_phone2id_path, phone_id_map)) { + if (0 != GenDict(_phone2id_path, &phone_id_map)) { LOG(ERROR) << "Genarate phone2id dict failed"; return -1; } // 生成音调字典(音调到音调id的映射) if (_seperate_tone == "true") { - if (0 != GenDict(_tone2id_path, tone_id_map)) { + if (0 != GenDict(_tone2id_path, &tone_id_map)) { LOG(ERROR) << "Genarate tone2id dict failed"; return -1; - } + } } // 生成繁简字典(繁体到简体id的映射) - if (0 != GenDict(_trand2simp_path, trand_simp_map)) { + if (0 != GenDict(_trand2simp_path, &trand_simp_map)) { LOG(ERROR) << "Genarate trand2simp dict failed"; return -1; } @@ -113,14 +153,14 @@ int FrontEngineInterface::ReadConfFile() { while (std::getline(is, line)) { if (line.substr(0, 2) == "--") { size_t pos = line.find_first_of("=", 0); - std::string key = line.substr(2, pos-2); + std::string key = line.substr(2, pos - 2); std::string value = line.substr(pos + 1); conf_map[key] = value; LOG(INFO) << "Key: " << key << "; Value: " << value; } } - // jieba conf path + // jieba conf path _jieba_dict_path = conf_map["jieba_dict_path"]; _jieba_hmm_path = conf_map["jieba_hmm_path"]; _jieba_user_dict_path = conf_map["jieba_user_dict_path"]; @@ -137,23 +177,26 @@ int FrontEngineInterface::ReadConfFile() { return 0; } -int FrontEngineInterface::Trand2Simp(const std::wstring &sentence, std::wstring &sentence_simp) { - //sentence_simp = sentence; - for(int i = 0; i < sentence.length(); i++) { +int FrontEngineInterface::Trand2Simp(const std::wstring &sentence, + std::wstring *sentence_simp) { + // sentence_simp = sentence; + for (int i = 0; i < sentence.length(); i++) { std::wstring temp(1, sentence[i]); std::string sigle_word = ppspeech::wstring2utf8string(temp); // 单个字是否在繁转简的字典里 - if(trand_simp_map.find(sigle_word) == trand_simp_map.end()) { - sentence_simp += temp; + if (trand_simp_map.find(sigle_word) == trand_simp_map.end()) { + sentence_simp->append(temp); } else { - sentence_simp += (ppspeech::utf8string2wstring(trand_simp_map[sigle_word])); + sentence_simp->append( + (ppspeech::utf8string2wstring(trand_simp_map[sigle_word]))); } } return 0; } -int FrontEngineInterface::GenDict(const std::string &dict_file, std::map &map) { +int FrontEngineInterface::GenDict(const std::string &dict_file, + std::map *map) { std::ifstream is(dict_file.c_str(), std::ifstream::in); if (!is.good()) { LOG(ERROR) << "Cannot open dict file: " << dict_file; @@ -163,28 +206,32 @@ int FrontEngineInterface::GenDict(const std::string &dict_file, std::map> &seg, - std::vector &seg_words) { - std::vector> ::iterator iter; - for(iter=seg.begin(); iter!=seg.end(); iter++) { - seg_words.push_back((*iter).first); +int FrontEngineInterface::GetSegResult( + std::vector> *seg, + std::vector *seg_words) { + std::vector>::iterator iter; + for (iter = seg->begin(); iter != seg->end(); iter++) { + seg_words->push_back((*iter).first); } return 0; } -int FrontEngineInterface::GetSentenceIds(const std::string &sentence, std::vector &phoneids, std::vector &toneids) { - std::vector> cut_result; //分词结果包含词和词性 - if (0 != Cut(sentence, cut_result)) { +int FrontEngineInterface::GetSentenceIds(const std::string &sentence, + std::vector *phoneids, + std::vector *toneids) { + std::vector> + cut_result; //分词结果包含词和词性 + if (0 != Cut(sentence, &cut_result)) { LOG(ERROR) << "Cut sentence: \"" << sentence << "\" failed"; return -1; } - + if (0 != GetWordsIds(cut_result, phoneids, toneids)) { LOG(ERROR) << "Get words phoneids failed"; return -1; @@ -192,81 +239,89 @@ int FrontEngineInterface::GetSentenceIds(const std::string &sentence, std::vecto return 0; } -int FrontEngineInterface::GetWordsIds(const std::vector> &cut_result, std::vector &phoneids, - std::vector &toneids) { +int FrontEngineInterface::GetWordsIds( + const std::vector> &cut_result, + std::vector *phoneids, + std::vector *toneids) { std::string word; std::string pos; std::vector word_initials; std::vector word_finals; std::string phone; - for(int i = 0; i < cut_result.size(); i++) { + for (int i = 0; i < cut_result.size(); i++) { word = cut_result[i].first; pos = cut_result[i].second; - if (std::find(_punc_omit.begin(), _punc_omit.end(), word) == _punc_omit.end()) { // 非可忽略的标点 + if (std::find(_punc_omit.begin(), _punc_omit.end(), word) == + _punc_omit.end()) { // 非可忽略的标点 word_initials = {}; word_finals = {}; phone = ""; // 判断是否在标点符号集合中 - if (std::find(_punc.begin(), _punc.end(), word) == _punc.end()) { // 文字 + if (std::find(_punc.begin(), _punc.end(), word) == + _punc.end()) { // 文字 // 获取字词的声母韵母列表 - if(0 != GetInitialsFinals(word, word_initials, word_finals)) { - LOG(ERROR) << "Genarate the word_initials and word_finals of " << word << " failed"; + if (0 != + GetInitialsFinals(word, &word_initials, &word_finals)) { + LOG(ERROR) + << "Genarate the word_initials and word_finals of " + << word << " failed"; return -1; } - + // 对读音进行修改 - if(0 != ModifyTone(word, pos, word_finals)) { + if (0 != ModifyTone(word, pos, &word_finals)) { LOG(ERROR) << "Failed to modify tone."; } // 对儿化音进行修改 - std::vector> new_initals_finals = MergeErhua(word_initials, word_finals, word, pos); + std::vector> new_initals_finals = + MergeErhua(word_initials, word_finals, word, pos); word_initials = new_initals_finals[0]; word_finals = new_initals_finals[1]; - + // 将声母和韵母合并成音素 assert(word_initials.size() == word_finals.size()); std::string temp_phone; - for(int j = 0; j < word_initials.size(); j++) { - if(word_initials[j] != "") { + for (int j = 0; j < word_initials.size(); j++) { + if (word_initials[j] != "") { temp_phone = word_initials[j] + " " + word_finals[j]; } else { temp_phone = word_finals[j]; } - if(j == 0) { + if (j == 0) { phone += temp_phone; } else { phone += (" " + temp_phone); } } - } else { // 标点符号 - if(_seperate_tone == "true") { - phone = "sp0"; // speedyspeech + } else { // 标点符号 + if (_seperate_tone == "true") { + phone = "sp0"; // speedyspeech } else { - phone = "sp"; // fastspeech2 - } + phone = "sp"; // fastspeech2 + } } // 音素到音素id - if(0 != Phone2Phoneid(phone, phoneids, toneids)) { - LOG(ERROR) << "Genarate the phone id of " << word << " failed"; + if (0 != Phone2Phoneid(phone, phoneids, toneids)) { + LOG(ERROR) << "Genarate the phone id of " << word << " failed"; return -1; } } } - return 0; - } -int FrontEngineInterface::Cut(const std::string &sentence, std::vector> &cut_result) { +int FrontEngineInterface::Cut( + const std::string &sentence, + std::vector> *cut_result) { std::vector> cut_result_jieba; - + // 结巴分词 _jieba->Tag(sentence, cut_result_jieba); // 对分词后结果进行整合 - if (0 != MergeforModify(cut_result_jieba, cut_result)) { + if (0 != MergeforModify(&cut_result_jieba, cut_result)) { LOG(ERROR) << "Failed to modify for word segmentation result."; return -1; } @@ -274,50 +329,57 @@ int FrontEngineInterface::Cut(const std::string &sentence, std::vector wordcut; _jieba->CutAll(word, wordcut); - phone = word_phone_map[wordcut[0]]; + phone->assign(word_phone_map[wordcut[0]]); for (int i = 1; i < wordcut.size(); i++) { - phone += (" " + word_phone_map[wordcut[i]]); + phone->assign((*phone) + (" " + word_phone_map[wordcut[i]])); } } else { - phone = word_phone_map[word]; + phone->assign(word_phone_map[word]); } return 0; } -int FrontEngineInterface::Phone2Phoneid(const std::string &phone, std::vector &phoneid, std::vector &toneid) { +int FrontEngineInterface::Phone2Phoneid(const std::string &phone, + std::vector *phoneid, + std::vector *toneid) { std::vector phone_vec; phone_vec = absl::StrSplit(phone, " "); std::string temp_phone; - for(int i = 0; i < phone_vec.size(); i++) { + for (int i = 0; i < phone_vec.size(); i++) { temp_phone = phone_vec[i]; - if(_seperate_tone == "true") { - phoneid.push_back(atoi((phone_id_map[temp_phone.substr(0, temp_phone.length()-1)]).c_str())); - toneid.push_back(atoi((tone_id_map[temp_phone.substr(temp_phone.length()-1, temp_phone.length())]).c_str())); - }else { - phoneid.push_back(atoi((phone_id_map[temp_phone]).c_str())); + if (_seperate_tone == "true") { + phoneid->push_back(atoi( + (phone_id_map[temp_phone.substr(0, temp_phone.length() - 1)]) + .c_str())); + toneid->push_back( + atoi((tone_id_map[temp_phone.substr(temp_phone.length() - 1, + temp_phone.length())]) + .c_str())); + } else { + phoneid->push_back(atoi((phone_id_map[temp_phone]).c_str())); } - } return 0; } // 根据韵母判断该词中每个字的读音都为第三声。true表示词中每个字都是第三声 -bool FrontEngineInterface::AllToneThree(const std::vector &finals) { +bool FrontEngineInterface::AllToneThree( + const std::vector &finals) { bool flags = true; - for(int i = 0; i < finals.size(); i++) { - if((int)finals[i].back() != 51) { //如果读音不为第三声 + for (int i = 0; i < finals.size(); i++) { + if (static_cast(finals[i].back()) != 51) { //如果读音不为第三声 flags = false; - } + } } return flags; - } // 判断词是否是叠词 @@ -325,45 +387,49 @@ bool FrontEngineInterface::IsReduplication(const std::string &word) { bool flags = false; std::wstring word_wstr = ppspeech::utf8string2wstring(word); int len = word_wstr.length(); - if(len == 2 && word_wstr[0] == word_wstr[1]){ + if (len == 2 && word_wstr[0] == word_wstr[1]) { flags = true; } return flags; - } -// 获取每个字词的声母和韵母列表, word_initials 为声母列表,word_finals 为韵母列表 -int FrontEngineInterface::GetInitialsFinals(const std::string &word, std::vector &word_initials, std::vector &word_finals) { - std::string phone; - GetPhone(word, phone); //获取字词对应的音素 +// 获取每个字词的声母和韵母列表, word_initials 为声母列表,word_finals +// 为韵母列表 +int FrontEngineInterface::GetInitialsFinals( + const std::string &word, + std::vector *word_initials, + std::vector *word_finals) { + std::string phone; + GetPhone(word, &phone); //获取字词对应的音素 std::vector phone_vec = absl::StrSplit(phone, " "); //获取韵母,每个字的音素有1或者2个,start为单个字音素的起始位置。 - int start = 0; - while(start < phone_vec.size()) { - if(phone_vec[start] == "sp" || phone_vec[start] == "sp0") { + int start = 0; + while (start < phone_vec.size()) { + if (phone_vec[start] == "sp" || phone_vec[start] == "sp0") { start += 1; - } - // 最后一位不是数字或者最后一位的数字是0,均表示声母,第二个是韵母 - else if(isdigit(phone_vec[start].back()) == 0 || (int)phone_vec[start].back() == 48) { - word_initials.push_back(phone_vec[start]); - word_finals.push_back(phone_vec[start + 1]); + } else if (isdigit(phone_vec[start].back()) == 0 || + static_cast(phone_vec[start].back()) == 48) { + word_initials->push_back(phone_vec[start]); + word_finals->push_back(phone_vec[start + 1]); start += 2; } else { - word_initials.push_back(""); - word_finals.push_back(phone_vec[start]); + word_initials->push_back(""); + word_finals->push_back(phone_vec[start]); start += 1; } } - - assert(word_finals.size() == ppspeech::utf8string2wstring(word).length() && word_finals.size() == word_initials.size()); + + assert(word_finals->size() == ppspeech::utf8string2wstring(word).length() && + word_finals->size() == word_initials->size()); return 0; } // 获取每个字词的韵母列表 -int FrontEngineInterface::GetFinals(const std::string &word, std::vector &word_finals) { +int FrontEngineInterface::GetFinals(const std::string &word, + std::vector *word_finals) { std::vector word_initials; - if(0 != GetInitialsFinals(word, word_initials, word_finals)) { + if (0 != GetInitialsFinals(word, &word_initials, word_finals)) { LOG(ERROR) << "Failed to get word finals"; return -1; } @@ -371,162 +437,189 @@ int FrontEngineInterface::GetFinals(const std::string &word, std::vector &wordvec) { +int FrontEngineInterface::Word2WordVec(const std::string &word, + std::vector *wordvec) { std::wstring word_wstr = ppspeech::utf8string2wstring(word); - for(int i = 0; i < word_wstr.length(); i++) { + for (int i = 0; i < word_wstr.length(); i++) { std::wstring word_sigle(1, word_wstr[i]); - wordvec.push_back(word_sigle); + wordvec->push_back(word_sigle); } return 0; - } // yuantian01解释:把一个词再进行分词找到。例子:小雨伞 --> 小 雨伞 或者 小雨 伞 -int FrontEngineInterface::SplitWord(const std::string &word, std::vector &new_word_vec) { +int FrontEngineInterface::SplitWord(const std::string &word, + std::vector *new_word_vec) { std::vector word_vec; std::string second_subword; _jieba->CutForSearch(word, word_vec); // 升序 - std::sort(word_vec.begin(), word_vec.end(), [](std::string a, std::string b ) {return a.size() > b.size();}); + std::sort(word_vec.begin(), + word_vec.end(), + [](std::string a, std::string b) { return a.size() > b.size(); }); std::string first_subword = word_vec[0]; // 提取长度最短的字符串 int first_begin_idx = word.find_first_of(first_subword); - if(first_begin_idx == 0) { + if (first_begin_idx == 0) { second_subword = word.substr(first_subword.length()); - new_word_vec.push_back(first_subword); - new_word_vec.push_back(second_subword); + new_word_vec->push_back(first_subword); + new_word_vec->push_back(second_subword); } else { second_subword = word.substr(0, word.length() - first_subword.length()); - new_word_vec.push_back(second_subword); - new_word_vec.push_back(first_subword); + new_word_vec->push_back(second_subword); + new_word_vec->push_back(first_subword); } return 0; - } -//example: 不 一起 --> 不一起 -std::vector> FrontEngineInterface::MergeBu(std::vector> &seg_result) { +// example: 不 一起 --> 不一起 +std::vector> FrontEngineInterface::MergeBu( + std::vector> *seg_result) { std::vector> result; std::string word; std::string pos; std::string last_word = ""; - - for(int i = 0; i < seg_result.size(); i++) { - word = seg_result[i].first; - pos = seg_result[i].second; - if(last_word == "不") { + + for (int i = 0; i < seg_result->size(); i++) { + word = std::get<0>((*seg_result)[i]); + pos = std::get<1>((*seg_result)[i]); + if (last_word == "不") { word = last_word + word; - } - if(word != "不") { + } + if (word != "不") { result.push_back(make_pair(word, pos)); - } + } last_word = word; } - if(last_word == "不") { + if (last_word == "不") { result.push_back(make_pair(last_word, "d")); last_word = ""; } - + return result; } -std::vector> FrontEngineInterface::Mergeyi(std::vector> &seg_result) { - std::vector> result_temp; +std::vector> FrontEngineInterface::Mergeyi( + std::vector> *seg_result) { + std::vector> *result_temp = + new std::vector>(); std::string word; std::string pos; - // function 1 example: 听 一 听 --> 听一听 - for(int i = 0; i < seg_result.size(); i++) { - word = seg_result[i].first; - pos = seg_result[i].second; - if((i - 1 >= 0) && (word == "一") && (i + 1 < seg_result.size()) && - (seg_result[i - 1].first == seg_result[i + 1].first) && seg_result[i - 1].second == "v") { - result_temp[i - 1].first = result_temp[i - 1].first + "一" + result_temp[i - 1].first; + for (int i = 0; i < seg_result->size(); i++) { + word = std::get<0>((*seg_result)[i]); + pos = std::get<1>((*seg_result)[i]); + + if ((i - 1 >= 0) && (word == "一") && (i + 1 < seg_result->size()) && + (std::get<0>((*seg_result)[i - 1]) == + std::get<0>((*seg_result)[i + 1])) && + std::get<1>((*seg_result)[i - 1]) == "v") { + std::get<0>((*result_temp)[i - 1]) = + std::get<0>((*result_temp)[i - 1]) + "一" + + std::get<0>((*result_temp)[i - 1]); + } else { + if ((i - 2 >= 0) && (std::get<0>((*seg_result)[i - 1]) == "一") && + (std::get<0>((*seg_result)[i - 2]) == word) && (pos == "v")) { + continue; } else { - if((i - 2 >= 0) && (seg_result[i - 1].first == "一") && (seg_result[i - 2].first == word) && (pos == "v")) { - continue; - } else{ - result_temp.push_back(make_pair(word, pos)); - } - } + result_temp->push_back(make_pair(word, pos)); + } + } } // function 2 example: 一 你 --> 一你 std::vector> result = {}; - for(int j = 0; j < result_temp.size(); j++) { - word = result_temp[j].first; - pos = result_temp[j].second; - if((result.size() != 0) && (result.back().first == "一")) { + for (int j = 0; j < result_temp->size(); j++) { + word = std::get<0>((*result_temp)[j]); + pos = std::get<1>((*result_temp)[j]); + if ((result.size() != 0) && (result.back().first == "一")) { result.back().first = result.back().first + word; } else { result.push_back(make_pair(word, pos)); - } - + } } - + return result; } // example: 你 你 --> 你你 -std::vector> FrontEngineInterface::MergeReduplication(std::vector> &seg_result) { +std::vector> +FrontEngineInterface::MergeReduplication( + std::vector> *seg_result) { std::vector> result; std::string word; std::string pos; - for(int i = 0; i < seg_result.size(); i++) { - word = seg_result[i].first; - pos = seg_result[i].second; - if((result.size() != 0) && (word == result.back().first)) { - result.back().first = result.back().first + seg_result[i].first; + for (int i = 0; i < seg_result->size(); i++) { + word = std::get<0>((*seg_result)[i]); + pos = std::get<1>((*seg_result)[i]); + if ((result.size() != 0) && (word == result.back().first)) { + result.back().first = + result.back().first + std::get<0>((*seg_result)[i]); } else { result.push_back(make_pair(word, pos)); } } - + return result; } -// the first and the second words are all_tone_three -std::vector> FrontEngineInterface::MergeThreeTones(std::vector> &seg_result) { +// the first and the second words are all_tone_three +std::vector> +FrontEngineInterface::MergeThreeTones( + std::vector> *seg_result) { std::vector> result; std::string word; - std::string pos; - std::vector> finals; //韵母数组 + std::string pos; + std::vector> finals; //韵母数组 std::vector word_final; - std::vector merge_last(seg_result.size(), false); + std::vector merge_last(seg_result->size(), false); // 判断最后一个分词结果是不是标点,不看标点的声母韵母 - int word_num = seg_result.size() - 1; - if(std::find(_punc.begin(), _punc.end(), seg_result[word_num].first) == _punc.end()){ // 最后一个分词结果不是标点 + int word_num = seg_result->size() - 1; + + // seg_result[word_num].first + if (std::find( + _punc.begin(), _punc.end(), std::get<0>((*seg_result)[word_num])) == + _punc.end()) { // 最后一个分词结果不是标点 word_num += 1; } // 获取韵母数组 - for(int i = 0; i < word_num; i++) { + for (int i = 0; i < word_num; i++) { word_final = {}; - word = seg_result[i].first; - pos = seg_result[i].second; - if(std::find(_punc_omit.begin(), _punc_omit.end(), word) == _punc_omit.end()) { // 非可忽略的标点,即文字 - if(0 != GetFinals(word, word_final)) { + word = std::get<0>((*seg_result)[i]); + pos = std::get<1>((*seg_result)[i]); + if (std::find(_punc_omit.begin(), _punc_omit.end(), word) == + _punc_omit.end()) { // 非可忽略的标点,即文字 + if (0 != GetFinals(word, &word_final)) { LOG(ERROR) << "Failed to get the final of word."; } - } + } - finals.push_back(word_final); + finals.push_back(word_final); } assert(word_num == finals.size()); // 对第三声读音的字词分词结果进行处理 - for(int i = 0; i < word_num; i++) { - word = seg_result[i].first; - pos = seg_result[i].second; - if(i - 1 >= 0 && AllToneThree(finals[i - 1]) && AllToneThree(finals[i]) && !merge_last[i - 1]) { - // if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi - if(!IsReduplication(seg_result[i - 1].first) && - (ppspeech::utf8string2wstring(seg_result[i - 1].first)).length() + (ppspeech::utf8string2wstring(word)).length() <= 3) { - result.back().first = result.back().first + seg_result[i].first; + for (int i = 0; i < word_num; i++) { + word = std::get<0>((*seg_result)[i]); + pos = std::get<1>((*seg_result)[i]); + if (i - 1 >= 0 && AllToneThree(finals[i - 1]) && + AllToneThree(finals[i]) && !merge_last[i - 1]) { + // if the last word is reduplication, not merge, because + // reduplication need to be _neural_sandhi + // seg_result[i - 1].first + if (!IsReduplication(std::get<0>((*seg_result)[i - 1])) && + (ppspeech::utf8string2wstring( + std::get<0>((*seg_result)[i - 1]))) + .length() + + (ppspeech::utf8string2wstring(word)).length() <= + 3) { + result.back().first = + result.back().first + std::get<0>((*seg_result)[i]); merge_last[i] = true; } else { result.push_back(make_pair(word, pos)); @@ -537,54 +630,73 @@ std::vector> FrontEngineInterface::MergeThre } //把标点的分词结果补上 - if(word_num < seg_result.size()) { - result.push_back(make_pair(seg_result[word_num].first, seg_result[word_num].second)); + if (word_num < seg_result->size()) { + result.push_back( + // seg_result[word_num].first seg_result[word_num].second + // std::get<0>((*seg_result)[word_num]) + make_pair(std::get<0>((*seg_result)[word_num]), + std::get<1>((*seg_result)[word_num]))); } return result; } -// the last char of first word and the first char of second word is tone_three -std::vector> FrontEngineInterface::MergeThreeTones2(std::vector> &seg_result) { +// the last char of first word and the first char of second word is tone_three +std::vector> +FrontEngineInterface::MergeThreeTones2( + std::vector> *seg_result) { std::vector> result; std::string word; - std::string pos; - std::vector> finals; //韵母数组 + std::string pos; + std::vector> finals; //韵母数组 std::vector word_final; - std::vector merge_last(seg_result.size(), false); + std::vector merge_last(seg_result->size(), false); // 判断最后一个分词结果是不是标点 - int word_num = seg_result.size() - 1; - if(std::find(_punc.begin(), _punc.end(), seg_result[word_num].first) == _punc.end()){ // 最后一个分词结果不是标点 + int word_num = seg_result->size() - 1; + if (std::find( + _punc.begin(), _punc.end(), std::get<0>((*seg_result)[word_num])) == + _punc.end()) { // 最后一个分词结果不是标点 word_num += 1; } // 获取韵母数组 - for(int i = 0; i < word_num; i++) { + for (int i = 0; i < word_num; i++) { word_final = {}; - word = seg_result[i].first; - pos = seg_result[i].second; + word = std::get<0>((*seg_result)[i]); + pos = std::get<1>((*seg_result)[i]); // 如果是文字,则获取韵母,如果是可忽略的标点,例如引号,则跳过 - if(std::find(_punc_omit.begin(), _punc_omit.end(), word) == _punc_omit.end()) { - if(0 != GetFinals(word, word_final)) { + if (std::find(_punc_omit.begin(), _punc_omit.end(), word) == + _punc_omit.end()) { + if (0 != GetFinals(word, &word_final)) { LOG(ERROR) << "Failed to get the final of word."; } - } + } - finals.push_back(word_final); + finals.push_back(word_final); } assert(word_num == finals.size()); // 对第三声读音的字词分词结果进行处理 - for(int i = 0; i < word_num; i++) { - word = seg_result[i].first; - pos = seg_result[i].second; - if(i - 1 >= 0 && !finals[i - 1].empty() && absl::EndsWith(finals[i - 1].back(), "3") == true && - !finals[i].empty() && absl::EndsWith(finals[i].front(), "3") == true && !merge_last[i - 1]) { - // if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi - if(!IsReduplication(seg_result[i - 1].first) && - (ppspeech::utf8string2wstring(seg_result[i - 1].first)).length() + ppspeech::utf8string2wstring(word).length() <= 3) { - result.back().first = result.back().first + seg_result[i].first; + for (int i = 0; i < word_num; i++) { + word = std::get<0>((*seg_result)[i]); + pos = std::get<1>((*seg_result)[i]); + if (i - 1 >= 0 && !finals[i - 1].empty() && + absl::EndsWith(finals[i - 1].back(), "3") == true && + !finals[i].empty() && + absl::EndsWith(finals[i].front(), "3") == true && + !merge_last[i - 1]) { + // if the last word is reduplication, not merge, because + // reduplication need to be _neural_sandhi + // seg_result[i - 1].first + if (!IsReduplication(std::get<0>((*seg_result)[i - 1])) && + (ppspeech::utf8string2wstring( + std::get<0>((*seg_result)[i - 1]))) + .length() + + ppspeech::utf8string2wstring(word).length() <= + 3) { + result.back().first = + result.back().first + std::get<0>((*seg_result)[i]); merge_last[i] = true; } else { result.push_back(make_pair(word, pos)); @@ -595,73 +707,86 @@ std::vector> FrontEngineInterface::MergeThre } //把标点的分词结果补上 - if(word_num < seg_result.size()) { - result.push_back(make_pair(seg_result[word_num].first, seg_result[word_num].second)); + if (word_num < seg_result->size()) { + result.push_back(make_pair(std::get<0>((*seg_result)[word_num]), + std::get<1>((*seg_result)[word_num]))); } return result; } // example: 吃饭 儿 --> 吃饭儿 -std::vector> FrontEngineInterface::MergeEr(std::vector> &seg_result) { +std::vector> FrontEngineInterface::MergeEr( + std::vector> *seg_result) { std::vector> result; std::string word; std::string pos; - for(int i = 0; i < seg_result.size(); i++) { - word = seg_result[i].first; - pos = seg_result[i].second; - if((i - 1 >= 0) && (word == "儿")){ - result.back().first = result.back().first + seg_result[i].first; + for (int i = 0; i < seg_result->size(); i++) { + word = std::get<0>((*seg_result)[i]); + pos = std::get<1>((*seg_result)[i]); + if ((i - 1 >= 0) && (word == "儿")) { + result.back().first = + result.back().first + std::get<0>((*seg_result)[i]); } else { - result.push_back(make_pair(word, pos)); + result.push_back(make_pair(word, pos)); } } return result; } -int FrontEngineInterface::MergeforModify(std::vector> &seg_word_type, - std::vector> &modify_seg_word_type) { - +int FrontEngineInterface::MergeforModify( + std::vector> *seg_word_type, + std::vector> *modify_seg_word_type) { std::vector seg_result; - GetSegResult(seg_word_type, seg_result); - LOG(INFO) << "Before merge, seg result is: " << limonp::Join(seg_result.begin(), seg_result.end(), "/"); - - modify_seg_word_type = MergeBu(seg_word_type); - modify_seg_word_type = Mergeyi(modify_seg_word_type); - modify_seg_word_type = MergeReduplication(modify_seg_word_type); - modify_seg_word_type = MergeThreeTones(modify_seg_word_type); - modify_seg_word_type = MergeThreeTones2(modify_seg_word_type); - modify_seg_word_type = MergeEr(modify_seg_word_type); - + GetSegResult(seg_word_type, &seg_result); + LOG(INFO) << "Before merge, seg result is: " + << limonp::Join(seg_result.begin(), seg_result.end(), "/"); + std::vector> tmp; + tmp = MergeBu(seg_word_type); + *modify_seg_word_type = tmp; + tmp = Mergeyi(modify_seg_word_type); + *modify_seg_word_type = tmp; + tmp = MergeReduplication(modify_seg_word_type); + *modify_seg_word_type = tmp; + tmp = MergeThreeTones(modify_seg_word_type); + *modify_seg_word_type = tmp; + tmp = MergeThreeTones2(modify_seg_word_type); + *modify_seg_word_type = tmp; + tmp = MergeEr(modify_seg_word_type); + *modify_seg_word_type = tmp; seg_result = {}; - GetSegResult(modify_seg_word_type, seg_result); - LOG(INFO) << "After merge, seg result is: " << limonp::Join(seg_result.begin(), seg_result.end(), "/"); + + GetSegResult(modify_seg_word_type, &seg_result); + LOG(INFO) << "After merge, seg result is: " + << limonp::Join(seg_result.begin(), seg_result.end(), "/"); return 0; } -int FrontEngineInterface::BuSandi(const std::string &word, std::vector &finals) { +int FrontEngineInterface::BuSandi(const std::string &word, + std::vector *finals) { std::wstring bu = L"不"; std::vector wordvec; // 一个词转成向量形式 - if(0 != Word2WordVec(word, wordvec)) { + if (0 != Word2WordVec(word, &wordvec)) { LOG(ERROR) << "Failed to get word vector"; return -1; } // e.g. 看不懂 b u4 --> b u5, 将韵母的最后一位替换成 5 - if(wordvec.size() == 3 && wordvec[1] == bu) { - finals[1] = finals[1].replace(finals[1].length() - 1, 1, "5"); + if (wordvec.size() == 3 && wordvec[1] == bu) { + (*finals)[1] = (*finals)[1].replace((*finals)[1].length() - 1, 1, "5"); } else { // e.g. 不怕 b u4 --> b u2, 将韵母的最后一位替换成 2 - for(int i = 0; i < wordvec.size(); i++) { - if(wordvec[i] == bu && i + 1 < wordvec.size() && - absl::EndsWith(finals[i + 1], "4") == true) { - finals[i] = finals[i].replace(finals[i].length() - 1, 1, "2"); - } + for (int i = 0; i < wordvec.size(); i++) { + if (wordvec[i] == bu && i + 1 < wordvec.size() && + absl::EndsWith((*finals)[i + 1], "4") == true) { + (*finals)[i] = + (*finals)[i].replace((*finals)[i].length() - 1, 1, "2"); + } } } @@ -669,11 +794,12 @@ int FrontEngineInterface::BuSandi(const std::string &word, std::vector &finals) { +int FrontEngineInterface::YiSandhi(const std::string &word, + std::vector *finals) { std::wstring yi = L"一"; std::vector wordvec; // 一个词转成向量形式 - if(0 != Word2WordVec(word, wordvec)) { + if (0 != Word2WordVec(word, &wordvec)) { LOG(ERROR) << "Failed to get word vector"; return -1; } @@ -681,44 +807,49 @@ int FrontEngineInterface::YiSandhi(const std::string &word, std::vector &finals) { +int FrontEngineInterface::NeuralSandhi(const std::string &word, + const std::string &pos, + std::vector *finals) { std::wstring word_wstr = ppspeech::utf8string2wstring(word); std::vector wordvec; // 一个词转成向量形式 - if(0 != Word2WordVec(word, wordvec)) { + if (0 != Word2WordVec(word, &wordvec)) { LOG(ERROR) << "Failed to get word vector"; return -1; } @@ -726,10 +857,12 @@ int FrontEngineInterface::NeuralSandhi(const std::string &word, const std::strin assert(word_num == word_wstr.length()); // 情况1:reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺 - for(int j = 0; j < wordvec.size(); j++) { + for (int j = 0; j < wordvec.size(); j++) { std::string inits = "nva"; - if(j - 1 >= 0 && wordvec[j] == wordvec[j - 1] && inits.find(pos[0]) != inits.npos) { - finals[j] = finals[j].replace(finals[j].length() - 1, 1, "5"); + if (j - 1 >= 0 && wordvec[j] == wordvec[j - 1] && + inits.find(pos[0]) != inits.npos) { + (*finals)[j] = + (*finals)[j].replace((*finals)[j].length() - 1, 1, "5"); } } @@ -747,147 +880,204 @@ int FrontEngineInterface::NeuralSandhi(const std::string &word, const std::strin std::wstring ge = L"个"; std::wstring xiushi = L"几有两半多各整每做是零一二三四六七八九"; auto ge_idx = word_wstr.find_first_of(ge); // 出现“个”的第一个位置 - - if(word_num >= 1 && yuqici.find(wordvec.back()) != yuqici.npos) { - finals.back() = finals.back().replace(finals.back().length() - 1, 1, "5"); - } else if(word_num >= 1 && de.find(wordvec.back()) != de.npos) { - finals.back() = finals.back().replace(finals.back().length() - 1, 1, "5"); - } else if(word_num == 1 && le.find(wordvec[0]) != le.npos && find(le_pos.begin(), le_pos.end(), pos) != le_pos.end()) { - finals.back() = finals.back().replace(finals.back().length() - 1, 1, "5"); - } else if(word_num > 1 && men.find(wordvec.back()) != men.npos && find(men_pos.begin(), men_pos.end(), pos) != men_pos.end() - && find(must_not_neural_tone_words.begin(), must_not_neural_tone_words.end(), word) != must_not_neural_tone_words.end()) { - finals.back() = finals.back().replace(finals.back().length() - 1, 1, "5"); - } else if(word_num > 1 && weizhi.find(wordvec.back()) != weizhi.npos && find(weizhi_pos.begin(), weizhi_pos.end(), pos) != weizhi_pos.end()) { - finals.back() = finals.back().replace(finals.back().length() - 1, 1, "5"); - } else if(word_num > 1 && dong.find(wordvec.back()) != dong.npos && fangxiang.find(wordvec[word_num - 2]) != fangxiang.npos) { - finals.back() = finals.back().replace(finals.back().length() - 1, 1, "5"); - } - // 情况3:对“个”字前面带有修饰词的字词读音处理 - else if((ge_idx != word_wstr.npos && ge_idx >= 1 && xiushi.find(wordvec[ge_idx - 1]) != xiushi.npos) - || word_wstr == ge) { - finals.back() = finals.back().replace(finals.back().length() - 1, 1, "5"); + + if (word_num >= 1 && yuqici.find(wordvec.back()) != yuqici.npos) { + (*finals).back() = + (*finals).back().replace((*finals).back().length() - 1, 1, "5"); + } else if (word_num >= 1 && de.find(wordvec.back()) != de.npos) { + (*finals).back() = + (*finals).back().replace((*finals).back().length() - 1, 1, "5"); + } else if (word_num == 1 && le.find(wordvec[0]) != le.npos && + find(le_pos.begin(), le_pos.end(), pos) != le_pos.end()) { + (*finals).back() = + (*finals).back().replace((*finals).back().length() - 1, 1, "5"); + } else if (word_num > 1 && men.find(wordvec.back()) != men.npos && + find(men_pos.begin(), men_pos.end(), pos) != men_pos.end() && + find(must_not_neural_tone_words.begin(), + must_not_neural_tone_words.end(), + word) != must_not_neural_tone_words.end()) { + (*finals).back() = + (*finals).back().replace((*finals).back().length() - 1, 1, "5"); + } else if (word_num > 1 && weizhi.find(wordvec.back()) != weizhi.npos && + find(weizhi_pos.begin(), weizhi_pos.end(), pos) != + weizhi_pos.end()) { + (*finals).back() = + (*finals).back().replace((*finals).back().length() - 1, 1, "5"); + } else if (word_num > 1 && dong.find(wordvec.back()) != dong.npos && + fangxiang.find(wordvec[word_num - 2]) != fangxiang.npos) { + (*finals).back() = + (*finals).back().replace((*finals).back().length() - 1, 1, "5"); + } else if ((ge_idx != word_wstr.npos && ge_idx >= 1 && + xiushi.find(wordvec[ge_idx - 1]) != xiushi.npos) || + word_wstr == ge) { + (*finals).back() = + (*finals).back().replace((*finals).back().length() - 1, 1, "5"); } else { - if(find(must_neural_tone_words.begin(), must_neural_tone_words.end(), word) != must_neural_tone_words.end() - || (word_num >= 2 && find(must_neural_tone_words.begin(), must_neural_tone_words.end(), ppspeech::wstring2utf8string(word_wstr.substr(word_num - 2))) != must_neural_tone_words.end())) { - finals.back() = finals.back().replace(finals.back().length() - 1, 1, "5"); - } + if (find(must_neural_tone_words.begin(), + must_neural_tone_words.end(), + word) != must_neural_tone_words.end() || + (word_num >= 2 && + find(must_neural_tone_words.begin(), + must_neural_tone_words.end(), + ppspeech::wstring2utf8string(word_wstr.substr( + word_num - 2))) != must_neural_tone_words.end())) { + (*finals).back() = + (*finals).back().replace((*finals).back().length() - 1, 1, "5"); + } } // 进行进一步分词,把长词切分更短些 std::vector word_list; - if(0 != SplitWord(word, word_list)) { + if (0 != SplitWord(word, &word_list)) { LOG(ERROR) << "Failed to split word."; return -1; } // 创建对应的 韵母列表 std::vector> finals_list; std::vector finals_temp; - finals_temp.assign(finals.begin(), finals.begin() + ppspeech::utf8string2wstring(word_list[0]).length()); + finals_temp.assign((*finals).begin(), + (*finals).begin() + + ppspeech::utf8string2wstring(word_list[0]).length()); finals_list.push_back(finals_temp); - finals_temp.assign(finals.begin() + ppspeech::utf8string2wstring(word_list[0]).length(), finals.end()); + finals_temp.assign( + (*finals).begin() + ppspeech::utf8string2wstring(word_list[0]).length(), + (*finals).end()); finals_list.push_back(finals_temp); - finals = {}; - for(int i = 0; i < word_list.size(); i++) { + finals = new std::vector(); + for (int i = 0; i < word_list.size(); i++) { std::wstring temp_wstr = ppspeech::utf8string2wstring(word_list[i]); - if((find(must_neural_tone_words.begin(), must_neural_tone_words.end(), word_list[i]) != must_neural_tone_words.end()) - || (temp_wstr.length() >= 2 && find(must_neural_tone_words.begin(), must_neural_tone_words.end(), ppspeech::wstring2utf8string(temp_wstr.substr(temp_wstr.length() - 2))) != must_neural_tone_words.end())) { - finals_list[i].back() = finals_list[i].back().replace(finals_list[i].back().length() - 1, 1, "5"); - } - finals.insert(finals.end(), finals_list[i].begin(), finals_list[i].end()); + if ((find(must_neural_tone_words.begin(), + must_neural_tone_words.end(), + word_list[i]) != must_neural_tone_words.end()) || + (temp_wstr.length() >= 2 && + find(must_neural_tone_words.begin(), + must_neural_tone_words.end(), + ppspeech::wstring2utf8string( + temp_wstr.substr(temp_wstr.length() - 2))) != + must_neural_tone_words.end())) { + finals_list[i].back() = finals_list[i].back().replace( + finals_list[i].back().length() - 1, 1, "5"); + } + (*finals).insert( + (*finals).end(), finals_list[i].begin(), finals_list[i].end()); } return 0; } -int FrontEngineInterface::ThreeSandhi(const std::string &word, std::vector &finals) { +int FrontEngineInterface::ThreeSandhi(const std::string &word, + std::vector *finals) { std::wstring word_wstr = ppspeech::utf8string2wstring(word); std::vector> finals_list; std::vector finals_temp; std::vector wordvec; // 一个词转成向量形式 - if(0 != Word2WordVec(word, wordvec)) { + if (0 != Word2WordVec(word, &wordvec)) { LOG(ERROR) << "Failed to get word vector"; return -1; } int word_num = wordvec.size(); assert(word_num == word_wstr.length()); - if(word_num == 2 && AllToneThree(finals)) { - finals[0] = finals[0].replace(finals[0].length() - 1, 1, "2"); - } else if(word_num == 3) { + if (word_num == 2 && AllToneThree((*finals))) { + (*finals)[0] = (*finals)[0].replace((*finals)[0].length() - 1, 1, "2"); + } else if (word_num == 3) { // 进行进一步分词,把长词切分更短些 std::vector word_list; - if(0 != SplitWord(word, word_list)) { + if (0 != SplitWord(word, &word_list)) { LOG(ERROR) << "Failed to split word."; return -1; } - if(AllToneThree(finals)) { + if (AllToneThree((*finals))) { std::wstring temp_wstr = ppspeech::utf8string2wstring(word_list[0]); - //disyllabic + monosyllabic, e.g. 蒙古/包 - if(temp_wstr.length() == 2) { - finals[0] = finals[0].replace(finals[0].length() - 1, 1, "2"); - finals[1] = finals[1].replace(finals[1].length() - 1, 1, "2"); - } else if(temp_wstr.length() == 1) { //monosyllabic + disyllabic, e.g. 纸/老虎 - finals[1] = finals[1].replace(finals[1].length() - 1, 1, "2"); + // disyllabic + monosyllabic, e.g. 蒙古/包 + if (temp_wstr.length() == 2) { + (*finals)[0] = + (*finals)[0].replace((*finals)[0].length() - 1, 1, "2"); + (*finals)[1] = + (*finals)[1].replace((*finals)[1].length() - 1, 1, "2"); + } else if (temp_wstr.length() == + 1) { // monosyllabic + disyllabic, e.g. 纸/老虎 + (*finals)[1] = + (*finals)[1].replace((*finals)[1].length() - 1, 1, "2"); } } else { // 创建对应的 韵母列表 finals_temp = {}; finals_list = {}; - finals_temp.assign(finals.begin(), finals.begin() + ppspeech::utf8string2wstring(word_list[0]).length()); + finals_temp.assign( + (*finals).begin(), + (*finals).begin() + + ppspeech::utf8string2wstring(word_list[0]).length()); finals_list.push_back(finals_temp); - finals_temp.assign(finals.begin() + ppspeech::utf8string2wstring(word_list[0]).length(), finals.end()); + finals_temp.assign( + (*finals).begin() + + ppspeech::utf8string2wstring(word_list[0]).length(), + (*finals).end()); finals_list.push_back(finals_temp); - - finals = {}; - for(int i = 0; i < finals_list.size(); i++) { + + finals = new std::vector(); + for (int i = 0; i < finals_list.size(); i++) { // e.g. 所有/人 - if(AllToneThree(finals_list[i]) && finals_list[i].size() == 2) { - finals_list[i][0] = finals_list[i][0].replace(finals_list[i][0].length() - 1, 1, "2"); - } else if(i == 1 && !(AllToneThree(finals_list[i])) && absl::EndsWith(finals_list[i][0], "3") == true - && absl::EndsWith(finals_list[0].back(), "3") == true) { - finals_list[0].back() = finals_list[0].back().replace(finals_list[0].back().length() - 1, 1, "2"); - } - + if (AllToneThree(finals_list[i]) && + finals_list[i].size() == 2) { + finals_list[i][0] = finals_list[i][0].replace( + finals_list[i][0].length() - 1, 1, "2"); + } else if (i == 1 && !(AllToneThree(finals_list[i])) && + absl::EndsWith(finals_list[i][0], "3") == true && + absl::EndsWith(finals_list[0].back(), "3") == true) { + finals_list[0].back() = finals_list[0].back().replace( + finals_list[0].back().length() - 1, 1, "2"); + } } - finals.insert(finals.end(), finals_list[0].begin(), finals_list[0].end()); - finals.insert(finals.end(), finals_list[1].begin(), finals_list[1].end()); + (*finals).insert( + (*finals).end(), finals_list[0].begin(), finals_list[0].end()); + (*finals).insert( + (*finals).end(), finals_list[1].begin(), finals_list[1].end()); } - } else if(word_num == 4) { //将成语拆分为两个长度为 2 的单词 + } else if (word_num == 4) { //将成语拆分为两个长度为 2 的单词 // 创建对应的 韵母列表 finals_temp = {}; finals_list = {}; - finals_temp.assign(finals.begin(), finals.begin() + 2); + finals_temp.assign((*finals).begin(), (*finals).begin() + 2); finals_list.push_back(finals_temp); - finals_temp.assign(finals.begin() + 2, finals.end()); + finals_temp.assign((*finals).begin() + 2, (*finals).end()); finals_list.push_back(finals_temp); - finals = {}; - for(int j = 0; j < finals_list.size(); j++){ - if(AllToneThree(finals_list[j])) { - finals_list[j][0] = finals_list[j][0].replace(finals_list[j][0].length() - 1, 1, "2"); + finals = new std::vector(); + for (int j = 0; j < finals_list.size(); j++) { + if (AllToneThree(finals_list[j])) { + finals_list[j][0] = finals_list[j][0].replace( + finals_list[j][0].length() - 1, 1, "2"); } - finals.insert(finals.end(), finals_list[j].begin(), finals_list[j].end()); + (*finals).insert( + (*finals).end(), finals_list[j].begin(), finals_list[j].end()); } - } return 0; } -int FrontEngineInterface::ModifyTone(const std::string &word, const std::string &pos, std::vector &finals) { - if((0 != BuSandi(word, finals)) || (0 != YiSandhi(word, finals)) || - (0 != NeuralSandhi(word, pos, finals)) || (0 != ThreeSandhi(word,finals))) { - LOG(ERROR) << "Failed to modify tone of the word: " << word; - return -1; - } +int FrontEngineInterface::ModifyTone(const std::string &word, + const std::string &pos, + std::vector *finals) { + if ((0 != BuSandi(word, finals)) || (0 != YiSandhi(word, finals)) || + (0 != NeuralSandhi(word, pos, finals)) || + (0 != ThreeSandhi(word, finals))) { + LOG(ERROR) << "Failed to modify tone of the word: " << word; + return -1; + } return 0; } -std::vector> FrontEngineInterface::MergeErhua(const std::vector &initials, const std::vector &finals, const std::string &word, const std::string &pos) { +std::vector> FrontEngineInterface::MergeErhua( + const std::vector &initials, + const std::vector &finals, + const std::string &word, + const std::string &pos) { std::vector new_initials = {}; std::vector new_finals = {}; std::vector> new_initials_finals; @@ -895,28 +1085,38 @@ std::vector> FrontEngineInterface::MergeErhua(const std std::wstring word_wstr = ppspeech::utf8string2wstring(word); std::vector wordvec; // 一个词转成向量形式 - if(0 != Word2WordVec(word, wordvec)) { + if (0 != Word2WordVec(word, &wordvec)) { LOG(ERROR) << "Failed to get word vector"; } int word_num = wordvec.size(); - if((find(must_erhua.begin(), must_erhua.end(), word) == must_erhua.end()) && - ((find(not_erhua.begin(), not_erhua.end(), word) != not_erhua.end()) || (find(specified_pos.begin(), specified_pos.end(), pos) != specified_pos.end()))) { + if ((find(must_erhua.begin(), must_erhua.end(), word) == + must_erhua.end()) && + ((find(not_erhua.begin(), not_erhua.end(), word) != not_erhua.end()) || + (find(specified_pos.begin(), specified_pos.end(), pos) != + specified_pos.end()))) { new_initials_finals.push_back(initials); new_initials_finals.push_back(finals); return new_initials_finals; } - if(finals.size() != word_num) { + if (finals.size() != word_num) { new_initials_finals.push_back(initials); new_initials_finals.push_back(finals); return new_initials_finals; } assert(finals.size() == word_num); - for(int i = 0; i < finals.size(); i++) { - if(i == finals.size() - 1 && wordvec[i] == L"儿" && (finals[i] == "er2" || finals[i] == "er5") && word_num >= 2 && - find(not_erhua.begin(), not_erhua.end(), ppspeech::wstring2utf8string(word_wstr.substr(word_wstr.length() - 2))) == not_erhua.end() && !new_finals.empty()) { - new_finals.back() = new_finals.back().substr(0, new_finals.back().length()-1) + "r" + new_finals.back().substr(new_finals.back().length()-1); + for (int i = 0; i < finals.size(); i++) { + if (i == finals.size() - 1 && wordvec[i] == L"儿" && + (finals[i] == "er2" || finals[i] == "er5") && word_num >= 2 && + find(not_erhua.begin(), + not_erhua.end(), + ppspeech::wstring2utf8string(word_wstr.substr( + word_wstr.length() - 2))) == not_erhua.end() && + !new_finals.empty()) { + new_finals.back() = + new_finals.back().substr(0, new_finals.back().length() - 1) + + "r" + new_finals.back().substr(new_finals.back().length() - 1); } else { new_initials.push_back(initials[i]); new_finals.push_back(finals[i]); @@ -926,8 +1126,5 @@ std::vector> FrontEngineInterface::MergeErhua(const std new_initials_finals.push_back(new_finals); return new_initials_finals; - -} - - } +} // namespace ppspeech diff --git a/demos/TTSCppFrontend/src/front/front_interface.h b/demos/TTSCppFrontend/src/front/front_interface.h index 8df026c8d..fc33a4de6 100644 --- a/demos/TTSCppFrontend/src/front/front_interface.h +++ b/demos/TTSCppFrontend/src/front/front_interface.h @@ -1,156 +1,198 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #ifndef PADDLE_TTS_SERVING_FRONT_FRONT_INTERFACE_H #define PADDLE_TTS_SERVING_FRONT_FRONT_INTERFACE_H +#include +#include #include -#include #include -#include -#include +#include //#include "utils/dir_utils.h" #include -#include "front/text_normalize.h" #include "absl/strings/str_split.h" +#include "front/text_normalize.h" namespace ppspeech { - - class FrontEngineInterface : public TextNormalizer{ - public: - FrontEngineInterface(std::string conf) : _conf_file(conf) { - TextNormalizer(); - _jieba = nullptr; - _initialed = false; - init(); - } - - int init(); - ~FrontEngineInterface() { - - } - - // 读取配置文件 - int ReadConfFile(); - - // 简体转繁体 - int Trand2Simp(const std::wstring &sentence, std::wstring &sentence_simp); - - // 生成字典 - int GenDict(const std::string &file, std::map &map); - - // 由 词+词性的分词结果转为仅包含词的结果 - int GetSegResult(std::vector> &seg, std::vector &seg_words); - - // 生成句子的音素,音调id。如果音素和音调未分开,则 toneids 为空(fastspeech2),反之则不为空(speedyspeech) - int GetSentenceIds(const std::string &sentence, std::vector &phoneids, std::vector &toneids); - - // 根据分词结果获取词的音素,音调id,并对读音进行适当修改 (ModifyTone)。如果音素和音调未分开,则 toneids 为空(fastspeech2),反之则不为空(speedyspeech) - int GetWordsIds(const std::vector> &cut_result, std::vector &phoneids, std::vector &toneids); - - // 结巴分词生成包含词和词性的分词结果,再对分词结果进行适当修改 (MergeforModify) - int Cut(const std::string &sentence, std::vector> &cut_result); - - // 字词到音素的映射,查找字典 - int GetPhone(const std::string &word, std::string &phone); - - // 音素到音素id - int Phone2Phoneid(const std::string &phone, std::vector &phoneid, std::vector &toneids); - - - // 根据韵母判断该词中每个字的读音都为第三声。true表示词中每个字都是第三声 - bool AllToneThree(const std::vector &finals); - - // 判断词是否是叠词 - bool IsReduplication(const std::string &word); - - // 获取每个字词的声母韵母列表 - int GetInitialsFinals(const std::string &word, std::vector &word_initials, std::vector &word_finals); - - // 获取每个字词的韵母列表 - int GetFinals(const std::string &word, std::vector &word_finals); - // 整个词转成向量形式,向量的每个元素对应词的一个字 - int Word2WordVec(const std::string &word, std::vector &wordvec); +class FrontEngineInterface : public TextNormalizer { + public: + explicit FrontEngineInterface(std::string conf) : _conf_file(conf) { + TextNormalizer(); + _jieba = nullptr; + _initialed = false; + init(); + } - // 将整个词重新进行 full cut,分词后,各个词会在词典中 - int SplitWord(const std::string &word, std::vector &fullcut_word); - - // 对分词结果进行处理:对包含“不”字的分词结果进行整理 - std::vector> MergeBu(std::vector> &seg_result); + int init(); + ~FrontEngineInterface() {} - // 对分词结果进行处理:对包含“一”字的分词结果进行整理 - std::vector> Mergeyi(std::vector> &seg_result); + // 读取配置文件 + int ReadConfFile(); - // 对分词结果进行处理:对前后相同的两个字进行合并 - std::vector> MergeReduplication(std::vector> &seg_result); + // 简体转繁体 + int Trand2Simp(const std::wstring &sentence, std::wstring *sentence_simp); - // 对一个词和后一个词他们的读音均为第三声的两个词进行合并 - std::vector> MergeThreeTones(std::vector> &seg_result); + // 生成字典 + int GenDict(const std::string &file, + std::map *map); - // 对一个词的最后一个读音和后一个词的第一个读音为第三声的两个词进行合并 - std::vector> MergeThreeTones2(std::vector> &seg_result); + // 由 词+词性的分词结果转为仅包含词的结果 + int GetSegResult(std::vector> *seg, + std::vector *seg_words); - // 对分词结果进行处理:对包含“儿”字的分词结果进行整理 - std::vector> MergeEr(std::vector> &seg_result); + // 生成句子的音素,音调id。如果音素和音调未分开,则 toneids + // 为空(fastspeech2),反之则不为空(speedyspeech) + int GetSentenceIds(const std::string &sentence, + std::vector *phoneids, + std::vector *toneids); - // 对分词结果进行处理、修改 - int MergeforModify(std::vector> &seg_result, std::vector> &merge_seg_result); + // 根据分词结果获取词的音素,音调id,并对读音进行适当修改 + // (ModifyTone)。如果音素和音调未分开,则 toneids + // 为空(fastspeech2),反之则不为空(speedyspeech) + int GetWordsIds( + const std::vector> &cut_result, + std::vector *phoneids, + std::vector *toneids); + // 结巴分词生成包含词和词性的分词结果,再对分词结果进行适当修改 + // (MergeforModify) + int Cut(const std::string &sentence, + std::vector> *cut_result); - // 对包含“不”字的相关词音调进行修改 - int BuSandi(const std::string &word, std::vector &finals); + // 字词到音素的映射,查找字典 + int GetPhone(const std::string &word, std::string *phone); - // 对包含“一”字的相关词音调进行修改 - int YiSandhi(const std::string &word, std::vector &finals); + // 音素到音素id + int Phone2Phoneid(const std::string &phone, + std::vector *phoneid, + std::vector *toneids); - // 对一些特殊词(包括量词,语助词等)的相关词音调进行修改 - int NeuralSandhi(const std::string &word, const std::string &pos, std::vector &finals); - // 对包含第三声的相关词音调进行修改 - int ThreeSandhi(const std::string &word, std::vector &finals); + // 根据韵母判断该词中每个字的读音都为第三声。true表示词中每个字都是第三声 + bool AllToneThree(const std::vector &finals); - // 对字词音调进行处理、修改 - int ModifyTone(const std::string &word, const std::string &pos, std::vector &finals); + // 判断词是否是叠词 + bool IsReduplication(const std::string &word); + + // 获取每个字词的声母韵母列表 + int GetInitialsFinals(const std::string &word, + std::vector *word_initials, + std::vector *word_finals); + // 获取每个字词的韵母列表 + int GetFinals(const std::string &word, + std::vector *word_finals); + + // 整个词转成向量形式,向量的每个元素对应词的一个字 + int Word2WordVec(const std::string &word, + std::vector *wordvec); + + // 将整个词重新进行 full cut,分词后,各个词会在词典中 + int SplitWord(const std::string &word, + std::vector *fullcut_word); + + // 对分词结果进行处理:对包含“不”字的分词结果进行整理 + std::vector> MergeBu( + std::vector> *seg_result); + + // 对分词结果进行处理:对包含“一”字的分词结果进行整理 + std::vector> Mergeyi( + std::vector> *seg_result); + + // 对分词结果进行处理:对前后相同的两个字进行合并 + std::vector> MergeReduplication( + std::vector> *seg_result); + + // 对一个词和后一个词他们的读音均为第三声的两个词进行合并 + std::vector> MergeThreeTones( + std::vector> *seg_result); + + // 对一个词的最后一个读音和后一个词的第一个读音为第三声的两个词进行合并 + std::vector> MergeThreeTones2( + std::vector> *seg_result); + + // 对分词结果进行处理:对包含“儿”字的分词结果进行整理 + std::vector> MergeEr( + std::vector> *seg_result); + + // 对分词结果进行处理、修改 + int MergeforModify( + std::vector> *seg_result, + std::vector> *merge_seg_result); - // 对儿化音进行处理 - std::vector> MergeErhua(const std::vector &initials, const std::vector &finals, const std::string &word, const std::string &pos); - + // 对包含“不”字的相关词音调进行修改 + int BuSandi(const std::string &word, std::vector *finals); - private: - bool _initialed; - cppjieba::Jieba *_jieba; - std::vector _punc; - std::vector _punc_omit; + // 对包含“一”字的相关词音调进行修改 + int YiSandhi(const std::string &word, std::vector *finals); + + // 对一些特殊词(包括量词,语助词等)的相关词音调进行修改 + int NeuralSandhi(const std::string &word, + const std::string &pos, + std::vector *finals); - std::string _conf_file; - std::map conf_map; - std::map word_phone_map; - std::map phone_id_map; - std::map tone_id_map; - std::map trand_simp_map; + // 对包含第三声的相关词音调进行修改 + int ThreeSandhi(const std::string &word, std::vector *finals); + + // 对字词音调进行处理、修改 + int ModifyTone(const std::string &word, + const std::string &pos, + std::vector *finals); - std::string _jieba_dict_path; - std::string _jieba_hmm_path; - std::string _jieba_user_dict_path; - std::string _jieba_idf_path; - std::string _jieba_stop_word_path; + // 对儿化音进行处理 + std::vector> MergeErhua( + const std::vector &initials, + const std::vector &finals, + const std::string &word, + const std::string &pos); + - std::string _seperate_tone; - std::string _word2phone_path; - std::string _phone2id_path; - std::string _tone2id_path; - std::string _trand2simp_path; + private: + bool _initialed; + cppjieba::Jieba *_jieba; + std::vector _punc; + std::vector _punc_omit; - std::vector must_erhua; - std::vector not_erhua; + std::string _conf_file; + std::map conf_map; + std::map word_phone_map; + std::map phone_id_map; + std::map tone_id_map; + std::map trand_simp_map; - std::vector must_not_neural_tone_words; - std::vector must_neural_tone_words; + std::string _jieba_dict_path; + std::string _jieba_hmm_path; + std::string _jieba_user_dict_path; + std::string _jieba_idf_path; + std::string _jieba_stop_word_path; + std::string _seperate_tone; + std::string _word2phone_path; + std::string _phone2id_path; + std::string _tone2id_path; + std::string _trand2simp_path; + + std::vector must_erhua; + std::vector not_erhua; - }; -} + std::vector must_not_neural_tone_words; + std::vector must_neural_tone_words; +}; +} // namespace ppspeech #endif \ No newline at end of file diff --git a/demos/TTSCppFrontend/src/front/text_normalize.cpp b/demos/TTSCppFrontend/src/front/text_normalize.cpp index 11a493ba9..8420e8407 100644 --- a/demos/TTSCppFrontend/src/front/text_normalize.cpp +++ b/demos/TTSCppFrontend/src/front/text_normalize.cpp @@ -1,10 +1,22 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "front/text_normalize.h" namespace ppspeech { // 初始化 digits_map and unit_map int TextNormalizer::InitMap() { - digits_map["0"] = "零"; digits_map["1"] = "一"; digits_map["2"] = "二"; @@ -21,77 +33,84 @@ int TextNormalizer::InitMap() { units_map[3] = "千"; units_map[4] = "万"; units_map[8] = "亿"; - + return 0; } // 替换 -int TextNormalizer::Replace(std::wstring &sentence, const int &pos, const int &len, const std::wstring &repstr) { +int TextNormalizer::Replace(std::wstring *sentence, + const int &pos, + const int &len, + const std::wstring &repstr) { // 删除原来的 - sentence.erase(pos, len); + sentence->erase(pos, len); // 插入新的 - sentence.insert(pos, repstr); + sentence->insert(pos, repstr); return 0; - } // 根据标点符号切分句子 -int TextNormalizer::SplitByPunc(const std::wstring &sentence, std::vector &sentence_part) { +int TextNormalizer::SplitByPunc(const std::wstring &sentence, + std::vector *sentence_part) { std::wstring temp = sentence; std::wregex reg(L"[:,;。?!,;?!]"); std::wsmatch match; - while (std::regex_search (temp, match, reg)) { - sentence_part.push_back(temp.substr(0, match.position(0) + match.length(0))); - Replace(temp, 0, match.position(0) + match.length(0), L""); + while (std::regex_search(temp, match, reg)) { + sentence_part->push_back( + temp.substr(0, match.position(0) + match.length(0))); + Replace(&temp, 0, match.position(0) + match.length(0), L""); } // 如果最后没有标点符号 - if(temp != L"") { - sentence_part.push_back(temp); + if (temp != L"") { + sentence_part->push_back(temp); } return 0; } -//数字转文本,10200 - > 一万零二百 -std::string TextNormalizer::CreateTextValue(const std::string &num_str, bool use_zero) { - - std::string num_lstrip = std::string(absl::StripPrefix(num_str, "0")).data(); +// 数字转文本,10200 - > 一万零二百 +std::string TextNormalizer::CreateTextValue(const std::string &num_str, + bool use_zero) { + std::string num_lstrip = + std::string(absl::StripPrefix(num_str, "0")).data(); int len = num_lstrip.length(); - - if(len == 0) { + + if (len == 0) { return ""; } else if (len == 1) { - if(use_zero && (len < num_str.length())) { + if (use_zero && (len < num_str.length())) { return digits_map["0"] + digits_map[num_lstrip]; } else { return digits_map[num_lstrip]; } } else { - int largest_unit = 0; // 最大单位 + int largest_unit = 0; // 最大单位 std::string first_part; std::string second_part; - if (len > 1 and len <= 2) { + if (len > 1 && len <= 2) { largest_unit = 1; - } else if (len > 2 and len <= 3) { + } else if (len > 2 && len <= 3) { largest_unit = 2; - } else if (len > 3 and len <= 4) { + } else if (len > 3 && len <= 4) { largest_unit = 3; - } else if (len > 4 and len <= 8) { + } else if (len > 4 && len <= 8) { largest_unit = 4; } else if (len > 8) { - largest_unit = 8; - } + largest_unit = 8; + } first_part = num_str.substr(0, num_str.length() - largest_unit); second_part = num_str.substr(num_str.length() - largest_unit); - - return CreateTextValue(first_part, use_zero) + units_map[largest_unit] + CreateTextValue(second_part, use_zero); + + return CreateTextValue(first_part, use_zero) + units_map[largest_unit] + + CreateTextValue(second_part, use_zero); } } -// 数字一个一个对应,可直接用于年份,电话,手机, -std::string TextNormalizer::SingleDigit2Text(const std::string &num_str, bool alt_one) { +// 数字一个一个对应,可直接用于年份,电话,手机, +std::string TextNormalizer::SingleDigit2Text(const std::string &num_str, + bool alt_one) { std::string text = ""; if (alt_one) { digits_map["1"] = "幺"; @@ -110,13 +129,16 @@ std::string TextNormalizer::SingleDigit2Text(const std::string &num_str, bool al return text; } -std::string TextNormalizer::SingleDigit2Text(const std::wstring &num, bool alt_one) { +std::string TextNormalizer::SingleDigit2Text(const std::wstring &num, + bool alt_one) { std::string num_str = wstring2utf8string(num); return SingleDigit2Text(num_str, alt_one); } // 数字整体对应,可直接用于月份,日期,数值整数部分 -std::string TextNormalizer::MultiDigit2Text(const std::string &num_str, bool alt_one, bool use_zero) { +std::string TextNormalizer::MultiDigit2Text(const std::string &num_str, + bool alt_one, + bool use_zero) { LOG(INFO) << "aaaaaaaaaaaaaaaa: " << alt_one << use_zero; if (alt_one) { digits_map["1"] = "幺"; @@ -124,18 +146,22 @@ std::string TextNormalizer::MultiDigit2Text(const std::string &num_str, bool alt digits_map["1"] = "一"; } - std::wstring result = utf8string2wstring(CreateTextValue(num_str, use_zero)); + std::wstring result = + utf8string2wstring(CreateTextValue(num_str, use_zero)); std::wstring result_0(1, result[0]); std::wstring result_1(1, result[1]); // 一十八 --> 十八 - if ((result_0 == utf8string2wstring(digits_map["1"])) && (result_1 == utf8string2wstring(units_map[1]))) { - return wstring2utf8string(result.substr(1,result.length())); + if ((result_0 == utf8string2wstring(digits_map["1"])) && + (result_1 == utf8string2wstring(units_map[1]))) { + return wstring2utf8string(result.substr(1, result.length())); } else { return wstring2utf8string(result); } } -std::string TextNormalizer::MultiDigit2Text(const std::wstring &num, bool alt_one, bool use_zero) { +std::string TextNormalizer::MultiDigit2Text(const std::wstring &num, + bool alt_one, + bool use_zero) { std::string num_str = wstring2utf8string(num); return MultiDigit2Text(num_str, alt_one, use_zero); } @@ -145,15 +171,20 @@ std::string TextNormalizer::Digits2Text(const std::string &num_str) { std::string text; std::vector integer_decimal; integer_decimal = absl::StrSplit(num_str, "."); - - if(integer_decimal.size() == 1) { // 整数 + + if (integer_decimal.size() == 1) { // 整数 text = MultiDigit2Text(integer_decimal[0]); - } else if(integer_decimal.size() == 2) { // 小数 - if(integer_decimal[0] == "") { // 无整数的小数类型,例如:.22 - text = "点" + SingleDigit2Text(std::string(absl::StripSuffix(integer_decimal[1], "0")).data()); + } else if (integer_decimal.size() == 2) { // 小数 + if (integer_decimal[0] == "") { // 无整数的小数类型,例如:.22 + text = "点" + + SingleDigit2Text( + std::string(absl::StripSuffix(integer_decimal[1], "0")) + .data()); } else { // 常规小数类型,例如:12.34 - text = MultiDigit2Text(integer_decimal[0]) + "点" + \ - SingleDigit2Text(std::string(absl::StripSuffix(integer_decimal[1], "0")).data()); + text = MultiDigit2Text(integer_decimal[0]) + "点" + + SingleDigit2Text( + std::string(absl::StripSuffix(integer_decimal[1], "0")) + .data()); } } else { return "The value does not conform to the numeric format"; @@ -168,23 +199,28 @@ std::string TextNormalizer::Digits2Text(const std::wstring &num) { } // 日期,2021年8月18日 --> 二零二一年八月十八日 -int TextNormalizer::ReData(std::wstring &sentence) { - std::wregex reg(L"(\\d{4}|\\d{2})年((0?[1-9]|1[0-2])月)?(((0?[1-9])|((1|2)[0-9])|30|31)([日号]))?"); +int TextNormalizer::ReData(std::wstring *sentence) { + std::wregex reg( + L"(\\d{4}|\\d{2})年((0?[1-9]|1[0-2])月)?(((0?[1-9])|((1|2)[0-9])|30|31)" + L"([日号]))?"); std::wsmatch match; std::string rep; - while (std::regex_search (sentence, match, reg)) { + while (std::regex_search(*sentence, match, reg)) { rep = ""; rep += SingleDigit2Text(match[1]) + "年"; - if(match[3] != L"") { + if (match[3] != L"") { rep += MultiDigit2Text(match[3], false, false) + "月"; } - if(match[5] != L"") { - rep += MultiDigit2Text(match[5], false, false) + wstring2utf8string(match[9]); + if (match[5] != L"") { + rep += MultiDigit2Text(match[5], false, false) + + wstring2utf8string(match[9]); } - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); - + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } return 0; @@ -192,255 +228,301 @@ int TextNormalizer::ReData(std::wstring &sentence) { // XX-XX-XX or XX/XX/XX 例如:2021/08/18 --> 二零二一年八月十八日 -int TextNormalizer::ReData2(std::wstring &sentence) { - std::wregex reg(L"(\\d{4})([- /.])(0[1-9]|1[012])\\2(0[1-9]|[12][0-9]|3[01])"); +int TextNormalizer::ReData2(std::wstring *sentence) { + std::wregex reg( + L"(\\d{4})([- /.])(0[1-9]|1[012])\\2(0[1-9]|[12][0-9]|3[01])"); std::wsmatch match; std::string rep; - - while (std::regex_search (sentence, match, reg)) { + + while (std::regex_search(*sentence, match, reg)) { rep = ""; rep += (SingleDigit2Text(match[1]) + "年"); rep += (MultiDigit2Text(match[3], false, false) + "月"); rep += (MultiDigit2Text(match[4], false, false) + "日"); - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); - + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } - + return 0; } // XX:XX:XX 09:09:02 --> 九点零九分零二秒 -int TextNormalizer::ReTime(std::wstring &sentence) { +int TextNormalizer::ReTime(std::wstring *sentence) { std::wregex reg(L"([0-1]?[0-9]|2[0-3]):([0-5][0-9])(:([0-5][0-9]))?"); std::wsmatch match; std::string rep; - - while (std::regex_search (sentence, match, reg)) { + + while (std::regex_search(*sentence, match, reg)) { rep = ""; rep += (MultiDigit2Text(match[1], false, false) + "点"); - if(absl::StartsWith(wstring2utf8string(match[2]), "0")) { + if (absl::StartsWith(wstring2utf8string(match[2]), "0")) { rep += "零"; } rep += (MultiDigit2Text(match[2]) + "分"); - if(absl::StartsWith(wstring2utf8string(match[4]), "0")) { + if (absl::StartsWith(wstring2utf8string(match[4]), "0")) { rep += "零"; } rep += (MultiDigit2Text(match[4]) + "秒"); - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } return 0; } // 温度,例如:-24.3℃ --> 零下二十四点三度 -int TextNormalizer::ReTemperature(std::wstring &sentence) { - std::wregex reg(L"(-?)(\\d+(\\.\\d+)?)(°C|℃|度|摄氏度)"); +int TextNormalizer::ReTemperature(std::wstring *sentence) { + std::wregex reg(L"(-?)(\\d+(\\.\\d+)?)(°C|℃|度|摄氏度)"); std::wsmatch match; std::string rep; std::string sign; std::vector integer_decimal; std::string unit; - while (std::regex_search (sentence, match, reg)) { + while (std::regex_search(*sentence, match, reg)) { match[1] == L"-" ? sign = "负" : sign = ""; - match[4] == L"摄氏度"? unit = "摄氏度" : unit = "度"; + match[4] == L"摄氏度" ? unit = "摄氏度" : unit = "度"; rep = sign + Digits2Text(match[2]) + unit; - - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } return 0; - } // 分数,例如: 1/3 --> 三分之一 -int TextNormalizer::ReFrac(std::wstring &sentence) { - std::wregex reg(L"(-?)(\\d+)/(\\d+)"); +int TextNormalizer::ReFrac(std::wstring *sentence) { + std::wregex reg(L"(-?)(\\d+)/(\\d+)"); std::wsmatch match; std::string sign; std::string rep; - while (std::regex_search (sentence, match, reg)) { + while (std::regex_search(*sentence, match, reg)) { match[1] == L"-" ? sign = "负" : sign = ""; - rep = sign + MultiDigit2Text(match[3]) + "分之" + MultiDigit2Text(match[2]); - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); + rep = sign + MultiDigit2Text(match[3]) + "分之" + + MultiDigit2Text(match[2]); + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } return 0; } // 百分数,例如:45.5% --> 百分之四十五点五 -int TextNormalizer::RePercentage(std::wstring &sentence) { - std::wregex reg(L"(-?)(\\d+(\\.\\d+)?)%"); +int TextNormalizer::RePercentage(std::wstring *sentence) { + std::wregex reg(L"(-?)(\\d+(\\.\\d+)?)%"); std::wsmatch match; std::string sign; std::string rep; std::vector integer_decimal; - - while (std::regex_search (sentence, match, reg)) { + + while (std::regex_search(*sentence, match, reg)) { match[1] == L"-" ? sign = "负" : sign = ""; rep = sign + "百分之" + Digits2Text(match[2]); - - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); + + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } - + return 0; } // 手机号码,例如:+86 18883862235 --> 八六幺八八八三八六二二三五 -int TextNormalizer::ReMobilePhone(std::wstring &sentence) { - std::wregex reg(L"(\\d)?((\\+?86 ?)?1([38]\\d|5[0-35-9]|7[678]|9[89])\\d{8})(\\d)?"); +int TextNormalizer::ReMobilePhone(std::wstring *sentence) { + std::wregex reg( + L"(\\d)?((\\+?86 ?)?1([38]\\d|5[0-35-9]|7[678]|9[89])\\d{8})(\\d)?"); std::wsmatch match; std::string rep; std::vector country_phonenum; - while (std::regex_search (sentence, match, reg)) { + while (std::regex_search(*sentence, match, reg)) { country_phonenum = absl::StrSplit(wstring2utf8string(match[0]), "+"); rep = ""; - for(int i = 0; i < country_phonenum.size(); i++) { + for (int i = 0; i < country_phonenum.size(); i++) { LOG(INFO) << country_phonenum[i]; rep += SingleDigit2Text(country_phonenum[i], true); } - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); - + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } - + return 0; } // 座机号码,例如:010-51093154 --> 零幺零五幺零九三幺五四 -int TextNormalizer::RePhone(std::wstring &sentence) { - std::wregex reg(L"(\\d)?((0(10|2[1-3]|[3-9]\\d{2})-?)?[1-9]\\d{6,7})(\\d)?"); +int TextNormalizer::RePhone(std::wstring *sentence) { + std::wregex reg( + L"(\\d)?((0(10|2[1-3]|[3-9]\\d{2})-?)?[1-9]\\d{6,7})(\\d)?"); std::wsmatch match; std::vector zone_phonenum; std::string rep; - while (std::regex_search (sentence, match, reg)) { + while (std::regex_search(*sentence, match, reg)) { rep = ""; zone_phonenum = absl::StrSplit(wstring2utf8string(match[0]), "-"); - for(int i = 0; i < zone_phonenum.size(); i ++) { + for (int i = 0; i < zone_phonenum.size(); i++) { rep += SingleDigit2Text(zone_phonenum[i], true); } - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } return 0; } // 范围,例如:60~90 --> 六十到九十 -int TextNormalizer::ReRange(std::wstring &sentence) { - std::wregex reg(L"((-?)((\\d+)(\\.\\d+)?)|(\\.(\\d+)))[-~]((-?)((\\d+)(\\.\\d+)?)|(\\.(\\d+)))"); +int TextNormalizer::ReRange(std::wstring *sentence) { + std::wregex reg( + L"((-?)((\\d+)(\\.\\d+)?)|(\\.(\\d+)))[-~]((-?)((\\d+)(\\.\\d+)?)|(\\.(" + L"\\d+)))"); std::wsmatch match; std::string rep; std::string sign1; std::string sign2; - while (std::regex_search (sentence, match, reg)) { + while (std::regex_search(*sentence, match, reg)) { rep = ""; match[2] == L"-" ? sign1 = "负" : sign1 = ""; - if(match[6] != L"") { + if (match[6] != L"") { rep += sign1 + Digits2Text(match[6]) + "到"; } else { rep += sign1 + Digits2Text(match[3]) + "到"; } match[9] == L"-" ? sign2 = "负" : sign2 = ""; - if(match[13] != L"") { + if (match[13] != L"") { rep += sign2 + Digits2Text(match[13]); } else { rep += sign2 + Digits2Text(match[10]); } - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } return 0; } // 带负号的整数,例如:-10 --> 负十 -int TextNormalizer::ReInterger(std::wstring &sentence) { - std::wregex reg(L"(-)(\\d+)"); +int TextNormalizer::ReInterger(std::wstring *sentence) { + std::wregex reg(L"(-)(\\d+)"); std::wsmatch match; std::string rep; - while (std::regex_search (sentence, match, reg)) { + while (std::regex_search(*sentence, match, reg)) { rep = "负" + MultiDigit2Text(match[2]); - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } - + return 0; } // 纯小数 -int TextNormalizer::ReDecimalNum(std::wstring &sentence) { - std::wregex reg(L"(-?)((\\d+)(\\.\\d+))|(\\.(\\d+))"); +int TextNormalizer::ReDecimalNum(std::wstring *sentence) { + std::wregex reg(L"(-?)((\\d+)(\\.\\d+))|(\\.(\\d+))"); std::wsmatch match; std::string sign; std::string rep; - //std::vector integer_decimal; - while (std::regex_search (sentence, match, reg)) { + // std::vector integer_decimal; + while (std::regex_search(*sentence, match, reg)) { match[1] == L"-" ? sign = "负" : sign = ""; - if(match[5] != L"") { + if (match[5] != L"") { rep = sign + Digits2Text(match[5]); } else { rep = sign + Digits2Text(match[2]); } - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } return 0; } // 正整数 + 量词 -int TextNormalizer::RePositiveQuantifiers(std::wstring &sentence) { - std::wstring common_quantifiers = L"(朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲| \ - 墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂| \ - 课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘| \ - 毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日| \ - 季|刻|时|周|天|秒|分|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万| \ - 万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"; - std::wregex reg(L"(\\d+)([多余几])?" + common_quantifiers); +int TextNormalizer::RePositiveQuantifiers(std::wstring *sentence) { + std::wstring common_quantifiers = + L"(朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|" + L"担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|" + L"溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|" + L"本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" + L"毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|" + L"合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|" + L"卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|纪|岁|世|更|" + L"夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|" + L"元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|" + L"百万|万|千|百|)块|角|毛|分)"; + std::wregex reg(L"(\\d+)([多余几])?" + common_quantifiers); std::wsmatch match; std::string rep; - while (std::regex_search (sentence, match, reg)) { + while (std::regex_search(*sentence, match, reg)) { rep = MultiDigit2Text(match[1]); - Replace(sentence, match.position(1), match.length(1), utf8string2wstring(rep)); + Replace(sentence, + match.position(1), + match.length(1), + utf8string2wstring(rep)); } return 0; } // 编号类数字,例如: 89757 --> 八九七五七 -int TextNormalizer::ReDefalutNum(std::wstring &sentence) { - std::wregex reg(L"\\d{3}\\d*"); +int TextNormalizer::ReDefalutNum(std::wstring *sentence) { + std::wregex reg(L"\\d{3}\\d*"); std::wsmatch match; - while (std::regex_search (sentence, match, reg)) { - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(SingleDigit2Text(match[0]))); + while (std::regex_search(*sentence, match, reg)) { + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(SingleDigit2Text(match[0]))); } return 0; } -int TextNormalizer::ReNumber(std::wstring &sentence) { - std::wregex reg(L"(-?)((\\d+)(\\.\\d+)?)|(\\.(\\d+))"); +int TextNormalizer::ReNumber(std::wstring *sentence) { + std::wregex reg(L"(-?)((\\d+)(\\.\\d+)?)|(\\.(\\d+))"); std::wsmatch match; std::string sign; std::string rep; - while (std::regex_search (sentence, match, reg)) { + while (std::regex_search(*sentence, match, reg)) { match[1] == L"-" ? sign = "负" : sign = ""; - if(match[5] != L"") { + if (match[5] != L"") { rep = sign + Digits2Text(match[5]); } else { rep = sign + Digits2Text(match[2]); } - - Replace(sentence, match.position(0), match.length(0), utf8string2wstring(rep)); + + Replace(sentence, + match.position(0), + match.length(0), + utf8string2wstring(rep)); } return 0; } // 整体正则,按顺序 -int TextNormalizer::SentenceNormalize(std::wstring &sentence) { +int TextNormalizer::SentenceNormalize(std::wstring *sentence) { ReData(sentence); ReData2(sentence); ReTime(sentence); @@ -452,11 +534,9 @@ int TextNormalizer::SentenceNormalize(std::wstring &sentence) { ReRange(sentence); ReInterger(sentence); ReDecimalNum(sentence); - RePositiveQuantifiers(sentence); + RePositiveQuantifiers(sentence); ReDefalutNum(sentence); ReNumber(sentence); - return 0; + return 0; } - - -} \ No newline at end of file +} // namespace ppspeech \ No newline at end of file diff --git a/demos/TTSCppFrontend/src/front/text_normalize.h b/demos/TTSCppFrontend/src/front/text_normalize.h index 20d502b82..4383fa1b4 100644 --- a/demos/TTSCppFrontend/src/front/text_normalize.h +++ b/demos/TTSCppFrontend/src/front/text_normalize.h @@ -1,11 +1,24 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #ifndef PADDLE_TTS_SERVING_FRONT_TEXT_NORMALIZE_H #define PADDLE_TTS_SERVING_FRONT_TEXT_NORMALIZE_H +#include +#include #include #include #include -#include -#include #include "absl/strings/str_split.h" #include "absl/strings/strip.h" #include "base/type_conv.h" @@ -13,50 +26,52 @@ namespace ppspeech { class TextNormalizer { -public: - TextNormalizer() { - InitMap(); - } - ~TextNormalizer() { - - } + public: + TextNormalizer() { InitMap(); } + ~TextNormalizer() {} int InitMap(); - int Replace(std::wstring &sentence, const int &pos, const int &len, const std::wstring &repstr); - int SplitByPunc(const std::wstring &sentence, std::vector &sentence_part); + int Replace(std::wstring *sentence, + const int &pos, + const int &len, + const std::wstring &repstr); + int SplitByPunc(const std::wstring &sentence, + std::vector *sentence_part); - std::string CreateTextValue(const std::string &num, bool use_zero=true); - std::string SingleDigit2Text(const std::string &num_str, bool alt_one = false); + std::string CreateTextValue(const std::string &num, bool use_zero = true); + std::string SingleDigit2Text(const std::string &num_str, + bool alt_one = false); std::string SingleDigit2Text(const std::wstring &num, bool alt_one = false); - std::string MultiDigit2Text(const std::string &num_str, bool alt_one = false, bool use_zero = true); - std::string MultiDigit2Text(const std::wstring &num, bool alt_one = false, bool use_zero = true); + std::string MultiDigit2Text(const std::string &num_str, + bool alt_one = false, + bool use_zero = true); + std::string MultiDigit2Text(const std::wstring &num, + bool alt_one = false, + bool use_zero = true); std::string Digits2Text(const std::string &num_str); std::string Digits2Text(const std::wstring &num); - int ReData(std::wstring &sentence); - int ReData2(std::wstring &sentence); - int ReTime(std::wstring &sentence); - int ReTemperature(std::wstring &sentence); - int ReFrac(std::wstring &sentence); - int RePercentage(std::wstring &sentence); - int ReMobilePhone(std::wstring &sentence); - int RePhone(std::wstring &sentence); - int ReRange(std::wstring &sentence); - int ReInterger(std::wstring &sentence); - int ReDecimalNum(std::wstring &sentence); - int RePositiveQuantifiers(std::wstring &sentence); - int ReDefalutNum(std::wstring &sentence); - int ReNumber(std::wstring &sentence); - int SentenceNormalize(std::wstring &sentence); - - -private: - std::map digits_map; - std::map units_map; + int ReData(std::wstring *sentence); + int ReData2(std::wstring *sentence); + int ReTime(std::wstring *sentence); + int ReTemperature(std::wstring *sentence); + int ReFrac(std::wstring *sentence); + int RePercentage(std::wstring *sentence); + int ReMobilePhone(std::wstring *sentence); + int RePhone(std::wstring *sentence); + int ReRange(std::wstring *sentence); + int ReInterger(std::wstring *sentence); + int ReDecimalNum(std::wstring *sentence); + int RePositiveQuantifiers(std::wstring *sentence); + int ReDefalutNum(std::wstring *sentence); + int ReNumber(std::wstring *sentence); + int SentenceNormalize(std::wstring *sentence); + private: + std::map digits_map; + std::map units_map; }; - -} +} // namespace ppspeech #endif \ No newline at end of file diff --git a/examples/opencpop/voc5/conf/finetune.yaml b/examples/opencpop/voc5/conf/finetune.yaml index 8e66b4e60..0022a67aa 100644 --- a/examples/opencpop/voc5/conf/finetune.yaml +++ b/examples/opencpop/voc5/conf/finetune.yaml @@ -157,7 +157,7 @@ discriminator_grad_norm: -1 # Discriminator's gradient norm. ########################################################### generator_train_start_steps: 1 # Number of steps to start to train discriminator. discriminator_train_start_steps: 0 # Number of steps to start to train discriminator. -train_max_steps: 650000 # Number of training steps. +train_max_steps: 2600000 # Number of training steps. save_interval_steps: 5000 # Interval steps to save checkpoint. eval_interval_steps: 1000 # Interval steps to evaluate the network. diff --git a/examples/other/tn/data/textnorm_test_cases.txt b/examples/other/tn/data/textnorm_test_cases.txt index 17e90d0b6..ba9e6529a 100644 --- a/examples/other/tn/data/textnorm_test_cases.txt +++ b/examples/other/tn/data/textnorm_test_cases.txt @@ -32,7 +32,7 @@ iPad Pro的秒控键盘这次也推出白色版本。|iPad Pro的秒控键盘这 明天有62%的概率降雨|明天有百分之六十二的概率降雨 这是固话0421-33441122|这是固话零四二一三三四四一一二二 这是手机+86 18544139121|这是手机八六一八五四四一三九一二一 -小王的身高是153.5cm,梦想是打篮球!我觉得有0.1%的可能性。|小王的身高是一百五十三点五cm,梦想是打篮球!我觉得有百分之零点一的可能性。 +小王的身高是153.5cm,梦想是打篮球!我觉得有0.1%的可能性。|小王的身高是一百五十三点五厘米,梦想是打篮球!我觉得有百分之零点一的可能性。 不管三七二十一|不管三七二十一 九九八十一难|九九八十一难 2018年5月23号上午10点10分|二零一八年五月二十三号上午十点十分 @@ -124,4 +124,4 @@ iPad Pro的秒控键盘这次也推出白色版本。|iPad Pro的秒控键盘这 12~23|十二到二十三 12-23|十二到二十三 25cm²|二十五平方厘米 -25m|米 \ No newline at end of file +25m|米 diff --git a/paddlespeech/t2s/models/vits/duration_predictor.py b/paddlespeech/t2s/models/vits/duration_predictor.py index b0bb68d0f..12177fbc2 100644 --- a/paddlespeech/t2s/models/vits/duration_predictor.py +++ b/paddlespeech/t2s/models/vits/duration_predictor.py @@ -155,12 +155,10 @@ class StochasticDurationPredictor(nn.Layer): z_u, z1 = paddle.split(z_q, [1, 1], 1) u = F.sigmoid(z_u) * x_mask z0 = (w - u) * x_mask - logdet_tot_q += paddle.sum( - (F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2]) - logq = (paddle.sum(-0.5 * - (math.log(2 * math.pi) + - (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q) - + tmp1 = (F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask + logdet_tot_q += paddle.sum(tmp1, [1, 2]) + tmp2 = -0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask + logq = (paddle.sum(tmp2, [1, 2]) - logdet_tot_q) logdet_tot = 0 z0, logdet = self.log_flow(z0, x_mask) logdet_tot += logdet @@ -168,8 +166,8 @@ class StochasticDurationPredictor(nn.Layer): for flow in self.flows: z, logdet = flow(z, x_mask, g=x, inverse=inverse) logdet_tot = logdet_tot + logdet - nll = (paddle.sum(0.5 * (math.log(2 * math.pi) + - (z**2)) * x_mask, [1, 2]) - logdet_tot) + tmp3 = 0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask + nll = (paddle.sum(tmp3, [1, 2]) - logdet_tot) # (B,) return nll + logq else: diff --git a/paddlespeech/t2s/models/vits/generator.py b/paddlespeech/t2s/models/vits/generator.py index fbd2d6653..44bd78984 100644 --- a/paddlespeech/t2s/models/vits/generator.py +++ b/paddlespeech/t2s/models/vits/generator.py @@ -371,8 +371,9 @@ class VITSGenerator(nn.Layer): # (B, H, T_text) s_p_sq_r = paddle.exp(-2 * logs_p) # (B, 1, T_text) + tmp1 = -0.5 * math.log(2 * math.pi) - logs_p neg_x_ent_1 = paddle.sum( - -0.5 * math.log(2 * math.pi) - logs_p, + tmp1, [1], keepdim=True, ) # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) @@ -384,8 +385,9 @@ class VITSGenerator(nn.Layer): z_p.transpose([0, 2, 1]), (m_p * s_p_sq_r), ) # (B, 1, T_text) + tmp2 = -0.5 * (m_p**2) * s_p_sq_r neg_x_ent_4 = paddle.sum( - -0.5 * (m_p**2) * s_p_sq_r, + tmp2, [1], keepdim=True, ) # (B, T_feats, T_text) @@ -403,7 +405,6 @@ class VITSGenerator(nn.Layer): w = attn.sum(2) dur_nll = self.duration_predictor(x, x_mask, w=w, g=g) dur_nll = dur_nll / paddle.sum(x_mask) - # expand the length to match with the feature sequence # (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats) m_p = paddle.matmul(attn.squeeze(1), @@ -511,8 +512,9 @@ class VITSGenerator(nn.Layer): # (B, H, T_text) s_p_sq_r = paddle.exp(-2 * logs_p) # (B, 1, T_text) + tmp3 = -0.5 * math.log(2 * math.pi) - logs_p neg_x_ent_1 = paddle.sum( - -0.5 * math.log(2 * math.pi) - logs_p, + tmp3, [1], keepdim=True, ) # (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text) @@ -524,8 +526,9 @@ class VITSGenerator(nn.Layer): z_p.transpose([0, 2, 1]), (m_p * s_p_sq_r), ) # (B, 1, T_text) + tmp4 = -0.5 * (m_p**2) * s_p_sq_r neg_x_ent_4 = paddle.sum( - -0.5 * (m_p**2) * s_p_sq_r, + tmp4, [1], keepdim=True, ) # (B, T_feats, T_text) diff --git a/paddlespeech/t2s/models/vits/transform.py b/paddlespeech/t2s/models/vits/transform.py index 61bd5ee2b..0edc1d09d 100644 --- a/paddlespeech/t2s/models/vits/transform.py +++ b/paddlespeech/t2s/models/vits/transform.py @@ -61,8 +61,12 @@ def piecewise_rational_quadratic_transform( def mask_preprocess(x, mask): + # bins.dtype = int32 B, C, T, bins = paddle.shape(x) - new_x = paddle.zeros([mask.sum(), bins]) + mask_int = paddle.cast(mask, dtype='int64') + # paddle.sum 输入是 int32 或 bool 的时候,输出是 int64 + # paddle.zeros (fill_constant) 的 shape 会被强制转成 int32 类型 + new_x = paddle.zeros([paddle.sum(mask_int), bins]) for i in range(bins): new_x[:, i] = x[:, :, :, i][mask] return new_x @@ -240,4 +244,7 @@ def rational_quadratic_spline( def _searchsorted(bin_locations, inputs, eps=1e-6): bin_locations[..., -1] += eps - return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1 + mask = inputs[..., None] >= bin_locations + mask_int = paddle.cast(mask, 'int64') + out = paddle.sum(mask_int, axis=-1) - 1 + return out