fix mix frontend, test=tts (#2299)

pull/2304/head
liangym 2 years ago committed by GitHub
parent 25b96405df
commit 043b21d3b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -61,7 +61,8 @@ class MixFrontend():
return False
def is_end(self, before_char, after_char) -> bool:
if ((self.is_alphabet(before_char) or before_char == " ") and (self.is_alphabet(after_char) or after_char == " ")):
if ((self.is_alphabet(before_char) or before_char == " ") and
(self.is_alphabet(after_char) or after_char == " ")):
return True
else:
return False
@ -86,10 +87,11 @@ class MixFrontend():
if point_index == 0 or point_index == len(text) - 1:
new_text = text
else:
if not self.is_end(text[point_index - 1], text[point_index + 1]):
if not self.is_end(text[point_index - 1], text[point_index +
1]):
new_text = text
else:
new_text = text[: point_index] + "" + text[point_index + 1:]
new_text = text[:point_index] + "" + text[point_index + 1:]
elif len(point_indexs) == 2:
first_index = point_indexs[0]
@ -97,7 +99,8 @@ class MixFrontend():
# first
if first_index != 0:
if not self.is_end(text[first_index - 1], text[first_index + 1]):
if not self.is_end(text[first_index - 1], text[first_index +
1]):
new_text += (text[:first_index] + ".")
else:
new_text += (text[:first_index] + "")
@ -106,18 +109,20 @@ class MixFrontend():
# last
if end_index != len(text) - 1:
if not self.is_end(text[end_index - 1], text[end_index + 1]):
new_text += text[point_indexs[-2] + 1 : ]
new_text += text[point_indexs[-2] + 1:]
else:
new_text += (text[point_indexs[-2] + 1 : end_index] + "" + text[end_index + 1 : ])
new_text += (text[point_indexs[-2] + 1:end_index] + "" +
text[end_index + 1:])
else:
new_text += "."
new_text += "."
else:
first_index = point_indexs[0]
end_index = point_indexs[-1]
# first
if first_index != 0:
if not self.is_end(text[first_index - 1], text[first_index + 1]):
if not self.is_end(text[first_index - 1], text[first_index +
1]):
new_text += (text[:first_index] + ".")
else:
new_text += (text[:first_index] + "")
@ -126,16 +131,20 @@ class MixFrontend():
# middle
for j in range(1, len(point_indexs) - 1):
point_index = point_indexs[j]
if not self.is_end(text[point_index - 1], text[point_index + 1]):
new_text += (text[point_indexs[j-1] + 1 : point_index] + ".")
if not self.is_end(text[point_index - 1], text[point_index +
1]):
new_text += (
text[point_indexs[j - 1] + 1:point_index] + ".")
else:
new_text += (text[point_indexs[j-1] + 1 : point_index] + "")
new_text += (
text[point_indexs[j - 1] + 1:point_index] + "")
# last
if end_index != len(text) - 1:
if not self.is_end(text[end_index - 1], text[end_index + 1]):
new_text += text[point_indexs[-2] + 1 : ]
new_text += text[point_indexs[-2] + 1:]
else:
new_text += (text[point_indexs[-2] + 1 : end_index] + "" + text[end_index + 1 : ])
new_text += (text[point_indexs[-2] + 1:end_index] + "" +
text[end_index + 1:])
else:
new_text += "."
@ -224,7 +233,7 @@ class MixFrontend():
def get_input_ids(self,
sentence: str,
merge_sentences: bool=True,
merge_sentences: bool=False,
get_tone_ids: bool=False,
add_sp: bool=True,
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
@ -232,28 +241,29 @@ class MixFrontend():
sentences = self._split(sentence)
phones_list = []
result = {}
for text in sentences:
phones_seg = []
segments = self._distinguish(text)
for seg in segments:
content = seg[0]
lang = seg[1]
if lang == "zh":
input_ids = self.zh_frontend.get_input_ids(
content,
merge_sentences=True,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
elif lang == "en":
input_ids = self.en_frontend.get_input_ids(
content, merge_sentences=True, to_tensor=to_tensor)
phones_seg.append(input_ids["phone_ids"][0])
if add_sp:
phones_seg.append(self.sp_id_tensor)
if content != '':
if lang == "en":
input_ids = self.en_frontend.get_input_ids(
content, merge_sentences=True, to_tensor=to_tensor)
else:
input_ids = self.zh_frontend.get_input_ids(
content,
merge_sentences=True,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
phones_seg.append(input_ids["phone_ids"][0])
if add_sp:
phones_seg.append(self.sp_id_tensor)
if phones_seg == []:
phones_seg.append(self.sp_id_tensor)
phones = paddle.concat(phones_seg)
phones_list.append(phones)

Loading…
Cancel
Save