diff --git a/paddlespeech/t2s/frontend/mix_frontend.py b/paddlespeech/t2s/frontend/mix_frontend.py index 6868d335..a681445c 100644 --- a/paddlespeech/t2s/frontend/mix_frontend.py +++ b/paddlespeech/t2s/frontend/mix_frontend.py @@ -61,7 +61,8 @@ class MixFrontend(): return False def is_end(self, before_char, after_char) -> bool: - if ((self.is_alphabet(before_char) or before_char == " ") and (self.is_alphabet(after_char) or after_char == " ")): + if ((self.is_alphabet(before_char) or before_char == " ") and + (self.is_alphabet(after_char) or after_char == " ")): return True else: return False @@ -86,10 +87,11 @@ class MixFrontend(): if point_index == 0 or point_index == len(text) - 1: new_text = text else: - if not self.is_end(text[point_index - 1], text[point_index + 1]): + if not self.is_end(text[point_index - 1], text[point_index + + 1]): new_text = text else: - new_text = text[: point_index] + "。" + text[point_index + 1:] + new_text = text[:point_index] + "。" + text[point_index + 1:] elif len(point_indexs) == 2: first_index = point_indexs[0] @@ -97,7 +99,8 @@ class MixFrontend(): # first if first_index != 0: - if not self.is_end(text[first_index - 1], text[first_index + 1]): + if not self.is_end(text[first_index - 1], text[first_index + + 1]): new_text += (text[:first_index] + ".") else: new_text += (text[:first_index] + "。") @@ -106,18 +109,20 @@ class MixFrontend(): # last if end_index != len(text) - 1: if not self.is_end(text[end_index - 1], text[end_index + 1]): - new_text += text[point_indexs[-2] + 1 : ] + new_text += text[point_indexs[-2] + 1:] else: - new_text += (text[point_indexs[-2] + 1 : end_index] + "。" + text[end_index + 1 : ]) + new_text += (text[point_indexs[-2] + 1:end_index] + "。" + + text[end_index + 1:]) else: - new_text += "." + new_text += "." else: first_index = point_indexs[0] end_index = point_indexs[-1] # first if first_index != 0: - if not self.is_end(text[first_index - 1], text[first_index + 1]): + if not self.is_end(text[first_index - 1], text[first_index + + 1]): new_text += (text[:first_index] + ".") else: new_text += (text[:first_index] + "。") @@ -126,16 +131,20 @@ class MixFrontend(): # middle for j in range(1, len(point_indexs) - 1): point_index = point_indexs[j] - if not self.is_end(text[point_index - 1], text[point_index + 1]): - new_text += (text[point_indexs[j-1] + 1 : point_index] + ".") + if not self.is_end(text[point_index - 1], text[point_index + + 1]): + new_text += ( + text[point_indexs[j - 1] + 1:point_index] + ".") else: - new_text += (text[point_indexs[j-1] + 1 : point_index] + "。") + new_text += ( + text[point_indexs[j - 1] + 1:point_index] + "。") # last if end_index != len(text) - 1: if not self.is_end(text[end_index - 1], text[end_index + 1]): - new_text += text[point_indexs[-2] + 1 : ] + new_text += text[point_indexs[-2] + 1:] else: - new_text += (text[point_indexs[-2] + 1 : end_index] + "。" + text[end_index + 1 : ]) + new_text += (text[point_indexs[-2] + 1:end_index] + "。" + + text[end_index + 1:]) else: new_text += "." @@ -224,7 +233,7 @@ class MixFrontend(): def get_input_ids(self, sentence: str, - merge_sentences: bool=True, + merge_sentences: bool=False, get_tone_ids: bool=False, add_sp: bool=True, to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]: @@ -232,28 +241,29 @@ class MixFrontend(): sentences = self._split(sentence) phones_list = [] result = {} - for text in sentences: phones_seg = [] segments = self._distinguish(text) for seg in segments: content = seg[0] lang = seg[1] - if lang == "zh": - input_ids = self.zh_frontend.get_input_ids( - content, - merge_sentences=True, - get_tone_ids=get_tone_ids, - to_tensor=to_tensor) - - elif lang == "en": - input_ids = self.en_frontend.get_input_ids( - content, merge_sentences=True, to_tensor=to_tensor) - - phones_seg.append(input_ids["phone_ids"][0]) - if add_sp: - phones_seg.append(self.sp_id_tensor) - + if content != '': + if lang == "en": + input_ids = self.en_frontend.get_input_ids( + content, merge_sentences=True, to_tensor=to_tensor) + else: + input_ids = self.zh_frontend.get_input_ids( + content, + merge_sentences=True, + get_tone_ids=get_tone_ids, + to_tensor=to_tensor) + + phones_seg.append(input_ids["phone_ids"][0]) + if add_sp: + phones_seg.append(self.sp_id_tensor) + + if phones_seg == []: + phones_seg.append(self.sp_id_tensor) phones = paddle.concat(phones_seg) phones_list.append(phones)