Merge branch 'develop' into develop

pull/2221/head
BarryKCL 3 years ago committed by GitHub
commit 2cf8cf6357
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -365,15 +365,15 @@ class Frontend():
print("----------------------------") print("----------------------------")
return phonemes return phonemes
def get_input_ids( def get_input_ids(self,
self,
sentence: str, sentence: str,
merge_sentences: bool=True, merge_sentences: bool=True,
get_tone_ids: bool=False, get_tone_ids: bool=False,
robot: bool=False, robot: bool=False,
print_info: bool=False, print_info: bool=False,
add_blank: bool=False, add_blank: bool=False,
blank_token: str="<pad>") -> Dict[str, List[paddle.Tensor]]: blank_token: str="<pad>",
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
phonemes = self.get_phonemes( phonemes = self.get_phonemes(
sentence, sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
@ -384,19 +384,21 @@ class Frontend():
tones = [] tones = []
temp_phone_ids = [] temp_phone_ids = []
temp_tone_ids = [] temp_tone_ids = []
for part_phonemes in phonemes: for part_phonemes in phonemes:
phones, tones = self._get_phone_tone( phones, tones = self._get_phone_tone(
part_phonemes, get_tone_ids=get_tone_ids) part_phonemes, get_tone_ids=get_tone_ids)
if add_blank: if add_blank:
phones = insert_after_character(phones, blank_token) phones = insert_after_character(phones, blank_token)
if tones: if tones:
tone_ids = self._t2id(tones) tone_ids = self._t2id(tones)
if to_tensor:
tone_ids = paddle.to_tensor(tone_ids) tone_ids = paddle.to_tensor(tone_ids)
temp_tone_ids.append(tone_ids) temp_tone_ids.append(tone_ids)
if phones: if phones:
phone_ids = self._p2id(phones) phone_ids = self._p2id(phones)
# if use paddle.to_tensor() in onnxruntime, the first time will be too low
if to_tensor:
phone_ids = paddle.to_tensor(phone_ids) phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids) temp_phone_ids.append(phone_ids)
if temp_tone_ids: if temp_tone_ids:

Loading…
Cancel
Save