|
|
|
@ -61,7 +61,8 @@ class MixFrontend():
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
def is_end(self, before_char, after_char) -> bool:
    """Tell whether both neighbouring characters are a letter or a space.

    Each character qualifies when ``self.is_alphabet`` accepts it or when
    it is exactly one ASCII space (" ").

    Args:
        before_char: The character on the left-hand side.
        after_char: The character on the right-hand side.

    Returns:
        True only when both characters qualify, otherwise False.
    """
    # Evaluation order mirrors the original condition: is_alphabet is tried
    # first for each side, and the right-hand side is only inspected when
    # the left-hand side already qualified.
    return bool(
        (self.is_alphabet(before_char) or before_char == " ") and
        (self.is_alphabet(after_char) or after_char == " "))
|
|
|
|
@ -86,10 +87,11 @@ class MixFrontend():
|
|
|
|
|
if point_index == 0 or point_index == len(text) - 1:
|
|
|
|
|
new_text = text
|
|
|
|
|
else:
|
|
|
|
|
if not self.is_end(text[point_index - 1], text[point_index + 1]):
|
|
|
|
|
if not self.is_end(text[point_index - 1], text[point_index +
|
|
|
|
|
1]):
|
|
|
|
|
new_text = text
|
|
|
|
|
else:
|
|
|
|
|
new_text = text[: point_index] + "。" + text[point_index + 1:]
|
|
|
|
|
new_text = text[:point_index] + "。" + text[point_index + 1:]
|
|
|
|
|
|
|
|
|
|
elif len(point_indexs) == 2:
|
|
|
|
|
first_index = point_indexs[0]
|
|
|
|
@ -97,7 +99,8 @@ class MixFrontend():
|
|
|
|
|
|
|
|
|
|
# first
|
|
|
|
|
if first_index != 0:
|
|
|
|
|
if not self.is_end(text[first_index - 1], text[first_index + 1]):
|
|
|
|
|
if not self.is_end(text[first_index - 1], text[first_index +
|
|
|
|
|
1]):
|
|
|
|
|
new_text += (text[:first_index] + ".")
|
|
|
|
|
else:
|
|
|
|
|
new_text += (text[:first_index] + "。")
|
|
|
|
@ -106,18 +109,20 @@ class MixFrontend():
|
|
|
|
|
# last
|
|
|
|
|
if end_index != len(text) - 1:
|
|
|
|
|
if not self.is_end(text[end_index - 1], text[end_index + 1]):
|
|
|
|
|
new_text += text[point_indexs[-2] + 1 : ]
|
|
|
|
|
new_text += text[point_indexs[-2] + 1:]
|
|
|
|
|
else:
|
|
|
|
|
new_text += (text[point_indexs[-2] + 1 : end_index] + "。" + text[end_index + 1 : ])
|
|
|
|
|
new_text += (text[point_indexs[-2] + 1:end_index] + "。" +
|
|
|
|
|
text[end_index + 1:])
|
|
|
|
|
else:
|
|
|
|
|
new_text += "."
|
|
|
|
|
new_text += "."
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
first_index = point_indexs[0]
|
|
|
|
|
end_index = point_indexs[-1]
|
|
|
|
|
# first
|
|
|
|
|
if first_index != 0:
|
|
|
|
|
if not self.is_end(text[first_index - 1], text[first_index + 1]):
|
|
|
|
|
if not self.is_end(text[first_index - 1], text[first_index +
|
|
|
|
|
1]):
|
|
|
|
|
new_text += (text[:first_index] + ".")
|
|
|
|
|
else:
|
|
|
|
|
new_text += (text[:first_index] + "。")
|
|
|
|
@ -126,16 +131,20 @@ class MixFrontend():
|
|
|
|
|
# middle
|
|
|
|
|
for j in range(1, len(point_indexs) - 1):
|
|
|
|
|
point_index = point_indexs[j]
|
|
|
|
|
if not self.is_end(text[point_index - 1], text[point_index + 1]):
|
|
|
|
|
new_text += (text[point_indexs[j-1] + 1 : point_index] + ".")
|
|
|
|
|
if not self.is_end(text[point_index - 1], text[point_index +
|
|
|
|
|
1]):
|
|
|
|
|
new_text += (
|
|
|
|
|
text[point_indexs[j - 1] + 1:point_index] + ".")
|
|
|
|
|
else:
|
|
|
|
|
new_text += (text[point_indexs[j-1] + 1 : point_index] + "。")
|
|
|
|
|
new_text += (
|
|
|
|
|
text[point_indexs[j - 1] + 1:point_index] + "。")
|
|
|
|
|
# last
|
|
|
|
|
if end_index != len(text) - 1:
|
|
|
|
|
if not self.is_end(text[end_index - 1], text[end_index + 1]):
|
|
|
|
|
new_text += text[point_indexs[-2] + 1 : ]
|
|
|
|
|
new_text += text[point_indexs[-2] + 1:]
|
|
|
|
|
else:
|
|
|
|
|
new_text += (text[point_indexs[-2] + 1 : end_index] + "。" + text[end_index + 1 : ])
|
|
|
|
|
new_text += (text[point_indexs[-2] + 1:end_index] + "。" +
|
|
|
|
|
text[end_index + 1:])
|
|
|
|
|
else:
|
|
|
|
|
new_text += "."
|
|
|
|
|
|
|
|
|
@ -224,7 +233,7 @@ class MixFrontend():
|
|
|
|
|
|
|
|
|
|
def get_input_ids(self,
|
|
|
|
|
sentence: str,
|
|
|
|
|
merge_sentences: bool=True,
|
|
|
|
|
merge_sentences: bool=False,
|
|
|
|
|
get_tone_ids: bool=False,
|
|
|
|
|
add_sp: bool=True,
|
|
|
|
|
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
|
|
|
|
@ -232,28 +241,29 @@ class MixFrontend():
|
|
|
|
|
sentences = self._split(sentence)
|
|
|
|
|
phones_list = []
|
|
|
|
|
result = {}
|
|
|
|
|
|
|
|
|
|
for text in sentences:
|
|
|
|
|
phones_seg = []
|
|
|
|
|
segments = self._distinguish(text)
|
|
|
|
|
for seg in segments:
|
|
|
|
|
content = seg[0]
|
|
|
|
|
lang = seg[1]
|
|
|
|
|
if lang == "zh":
|
|
|
|
|
input_ids = self.zh_frontend.get_input_ids(
|
|
|
|
|
content,
|
|
|
|
|
merge_sentences=True,
|
|
|
|
|
get_tone_ids=get_tone_ids,
|
|
|
|
|
to_tensor=to_tensor)
|
|
|
|
|
|
|
|
|
|
elif lang == "en":
|
|
|
|
|
input_ids = self.en_frontend.get_input_ids(
|
|
|
|
|
content, merge_sentences=True, to_tensor=to_tensor)
|
|
|
|
|
|
|
|
|
|
phones_seg.append(input_ids["phone_ids"][0])
|
|
|
|
|
if add_sp:
|
|
|
|
|
phones_seg.append(self.sp_id_tensor)
|
|
|
|
|
|
|
|
|
|
if content != '':
|
|
|
|
|
if lang == "en":
|
|
|
|
|
input_ids = self.en_frontend.get_input_ids(
|
|
|
|
|
content, merge_sentences=True, to_tensor=to_tensor)
|
|
|
|
|
else:
|
|
|
|
|
input_ids = self.zh_frontend.get_input_ids(
|
|
|
|
|
content,
|
|
|
|
|
merge_sentences=True,
|
|
|
|
|
get_tone_ids=get_tone_ids,
|
|
|
|
|
to_tensor=to_tensor)
|
|
|
|
|
|
|
|
|
|
phones_seg.append(input_ids["phone_ids"][0])
|
|
|
|
|
if add_sp:
|
|
|
|
|
phones_seg.append(self.sp_id_tensor)
|
|
|
|
|
|
|
|
|
|
if phones_seg == []:
|
|
|
|
|
phones_seg.append(self.sp_id_tensor)
|
|
|
|
|
phones = paddle.concat(phones_seg)
|
|
|
|
|
phones_list.append(phones)
|
|
|
|
|
|
|
|
|
|