diff --git a/paddlespeech/t2s/frontend/mix_frontend.py b/paddlespeech/t2s/frontend/mix_frontend.py index 65b47ddd8..19c98d53f 100644 --- a/paddlespeech/t2s/frontend/mix_frontend.py +++ b/paddlespeech/t2s/frontend/mix_frontend.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re from typing import Dict from typing import List @@ -30,7 +29,6 @@ class MixFrontend(): self.zh_frontend = Frontend( phone_vocab_path=phone_vocab_path, tone_vocab_path=tone_vocab_path) self.en_frontend = English(phone_vocab_path=phone_vocab_path) - self.SENTENCE_SPLITOR = re.compile(r'([:、,;。?!,;?!][”’]?)') self.sp_id = self.zh_frontend.vocab_phones["sp"] self.sp_id_tensor = paddle.to_tensor([self.sp_id]) @@ -47,114 +45,56 @@ class MixFrontend(): else: return False - def is_number(self, char): - if char >= '\u0030' and char <= '\u0039': - return True - else: - return False - def is_other(self, char): - if not (self.is_chinese(char) or self.is_number(char) or - self.is_alphabet(char)): + if not (self.is_chinese(char) or self.is_alphabet(char)): return True else: return False - def _replace(self, text: str) -> str: - new_text = text - - # get "." indexs - point_indexs = [] - index = -1 - for i in range(text.count(".")): - index = text.find(".", index + 1, len(text)) - point_indexs.append(index) - - # replace - if len(point_indexs) != 0: - for index in point_indexs: - ch = text[index - 1] - if self.is_alphabet(ch) or ch == " ": - new_text = new_text[:index] + "。" + new_text[index + 1:] - - return new_text - - def _split(self, text: str) -> List[str]: - text = re.sub(r'[《》【】<=>{}()()#&@“”^_|…\\]', '', text) - # 替换英文句子的句号 "." --> "。" 用于后续分句 - text = self._replace(text) - text = self.SENTENCE_SPLITOR.sub(r'\1\n', text) - text = text.strip() - sentences = [sentence.strip() for sentence in re.split(r'\n+', text)] - return sentences - - def _distinguish(self, text: str) -> List[str]: + def get_segment(self, text: str) -> List[str]: # sentence --> [ch_part, en_part, ch_part, ...] - segments = [] types = [] - flag = 0 temp_seg = "" temp_lang = "" # Determine the type of each character. type: blank, chinese, alphabet, number, unk and point. for ch in text: - if ch == ".": - types.append("point") - elif self.is_chinese(ch): + if self.is_chinese(ch): types.append("zh") elif self.is_alphabet(ch): types.append("en") - elif ch == " ": - types.append("blank") - elif self.is_number(ch): - types.append("num") else: - types.append("unk") + types.append("other") assert len(types) == len(text) for i in range(len(types)): - # find the first char of the seg if flag == 0: - # 首个字符是中文,英文或者数字 - if types[i] == "zh" or types[i] == "en" or types[i] == "num": - temp_seg += text[i] - temp_lang = types[i] - flag = 1 + temp_seg += text[i] + temp_lang = types[i] + flag = 1 else: - # 数字和小数点均与前面的字符合并,类型属于前面一个字符的类型 - if types[i] == temp_lang or types[i] == "num" or types[ - i] == "point": - temp_seg += text[i] - - # 数字与后面的任意字符都拼接 - elif temp_lang == "num": - temp_seg += text[i] - if types[i] == "zh" or types[i] == "en": + if temp_lang == "other": + if types[i] == temp_lang: + temp_seg += text[i] + else: + temp_seg += text[i] temp_lang = types[i] - # 如果是空格则与前面字符拼接 - elif types[i] == "blank": - temp_seg += text[i] - - elif types[i] == "unk": - pass - else: - segments.append((temp_seg, temp_lang)) - - if types[i] == "zh" or types[i] == "en": + if types[i] == temp_lang: + temp_seg += text[i] + elif types[i] == "other": + temp_seg += text[i] + else: + segments.append((temp_seg, temp_lang)) temp_seg = text[i] temp_lang = types[i] flag = 1 - else: - flag = 0 - temp_seg = "" - temp_lang = "" segments.append((temp_seg, temp_lang)) @@ -167,34 +107,30 @@ class MixFrontend(): add_sp: bool=True, to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]: - sentences = self._split(sentence) + segments = self.get_segment(sentence) + phones_list = [] result = {} - for text in sentences: - phones_seg = [] - segments = self._distinguish(text) - for seg in segments: - content = seg[0] - lang = seg[1] - if content != '': - if lang == "en": - input_ids = self.en_frontend.get_input_ids( - content, merge_sentences=True, to_tensor=to_tensor) - else: - input_ids = self.zh_frontend.get_input_ids( - content, - merge_sentences=True, - get_tone_ids=get_tone_ids, - to_tensor=to_tensor) - - phones_seg.append(input_ids["phone_ids"][0]) - if add_sp: - phones_seg.append(self.sp_id_tensor) - - if phones_seg == []: - phones_seg.append(self.sp_id_tensor) - phones = paddle.concat(phones_seg) - phones_list.append(phones) + + for seg in segments: + content = seg[0] + lang = seg[1] + if content != '': + if lang == "en": + input_ids = self.en_frontend.get_input_ids( + content, merge_sentences=False, to_tensor=to_tensor) + else: + input_ids = self.zh_frontend.get_input_ids( + content, + merge_sentences=False, + get_tone_ids=get_tone_ids, + to_tensor=to_tensor) + if add_sp: + input_ids["phone_ids"][-1] = paddle.concat( + [input_ids["phone_ids"][-1], self.sp_id_tensor]) + + for phones in input_ids["phone_ids"]: + phones_list.append(phones) if merge_sentences: merge_list = paddle.concat(phones_list)