You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
255 lines
11 KiB
255 lines
11 KiB
3 years ago
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
import re
|
||
|
from typing import Dict
|
||
|
from typing import List
|
||
|
|
||
|
import jieba.posseg as psg
|
||
|
import numpy as np
|
||
|
import paddle
|
||
|
from g2pM import G2pM
|
||
|
from pypinyin import lazy_pinyin
|
||
|
from pypinyin import Style
|
||
|
|
||
|
from parakeet.frontend.zh_normalization.text_normlization import TextNormalizer
|
||
|
from parakeet.frontend.generate_lexicon import generate_lexicon
|
||
|
from parakeet.frontend.tone_sandhi import ToneSandhi
|
||
|
|
||
|
|
||
|
class Frontend():
|
||
|
def __init__(self,
|
||
|
g2p_model="pypinyin",
|
||
|
phone_vocab_path=None,
|
||
|
tone_vocab_path=None):
|
||
|
self.tone_modifier = ToneSandhi()
|
||
|
self.text_normalizer = TextNormalizer()
|
||
|
self.punc = ":,;。?!“”‘’':,;.?!"
|
||
|
# g2p_model can be pypinyin and g2pM
|
||
|
self.g2p_model = g2p_model
|
||
|
if self.g2p_model == "g2pM":
|
||
|
self.g2pM_model = G2pM()
|
||
|
self.pinyin2phone = generate_lexicon(
|
||
|
with_tone=True, with_erhua=False)
|
||
|
self.must_erhua = {"小院儿", "胡同儿", "范儿", "老汉儿", "撒欢儿", "寻老礼儿", "妥妥儿"}
|
||
|
self.not_erhua = {
|
||
|
"虐儿", "为儿", "护儿", "瞒儿", "救儿", "替儿", "有儿", "一儿", "我儿", "俺儿", "妻儿",
|
||
|
"拐儿", "聋儿", "乞儿", "患儿", "幼儿", "孤儿", "婴儿", "婴幼儿", "连体儿", "脑瘫儿",
|
||
|
"流浪儿", "体弱儿", "混血儿", "蜜雪儿", "舫儿", "祖儿", "美儿", "应采儿", "可儿", "侄儿",
|
||
|
"孙儿", "侄孙儿", "女儿", "男儿", "红孩儿", "花儿", "虫儿", "马儿", "鸟儿", "猪儿", "猫儿",
|
||
|
"狗儿"
|
||
|
}
|
||
|
self.vocab_phones = {}
|
||
|
self.vocab_tones = {}
|
||
|
if phone_vocab_path:
|
||
|
with open(phone_vocab_path, 'rt') as f:
|
||
|
phn_id = [line.strip().split() for line in f.readlines()]
|
||
|
for phn, id in phn_id:
|
||
|
self.vocab_phones[phn] = int(id)
|
||
|
if tone_vocab_path:
|
||
|
with open(tone_vocab_path, 'rt') as f:
|
||
|
tone_id = [line.strip().split() for line in f.readlines()]
|
||
|
for tone, id in tone_id:
|
||
|
self.vocab_tones[tone] = int(id)
|
||
|
|
||
|
def _get_initials_finals(self, word: str) -> List[List[str]]:
|
||
|
initials = []
|
||
|
finals = []
|
||
|
if self.g2p_model == "pypinyin":
|
||
|
orig_initials = lazy_pinyin(
|
||
|
word, neutral_tone_with_five=True, style=Style.INITIALS)
|
||
|
orig_finals = lazy_pinyin(
|
||
|
word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
|
||
|
for c, v in zip(orig_initials, orig_finals):
|
||
|
if re.match(r'i\d', v):
|
||
|
if c in ['z', 'c', 's']:
|
||
|
v = re.sub('i', 'ii', v)
|
||
|
elif c in ['zh', 'ch', 'sh', 'r']:
|
||
|
v = re.sub('i', 'iii', v)
|
||
|
initials.append(c)
|
||
|
finals.append(v)
|
||
|
elif self.g2p_model == "g2pM":
|
||
|
pinyins = self.g2pM_model(word, tone=True, char_split=False)
|
||
|
for pinyin in pinyins:
|
||
|
pinyin = pinyin.replace("u:", "v")
|
||
|
if pinyin in self.pinyin2phone:
|
||
|
initial_final_list = self.pinyin2phone[pinyin].split(" ")
|
||
|
if len(initial_final_list) == 2:
|
||
|
initials.append(initial_final_list[0])
|
||
|
finals.append(initial_final_list[1])
|
||
|
elif len(initial_final_list) == 1:
|
||
|
initials.append('')
|
||
|
finals.append(initial_final_list[1])
|
||
|
else:
|
||
|
# If it's not pinyin (possibly punctuation) or no conversion is required
|
||
|
initials.append(pinyin)
|
||
|
finals.append(pinyin)
|
||
|
return initials, finals
|
||
|
|
||
|
# if merge_sentences, merge all sentences into one phone sequence
|
||
|
def _g2p(self,
|
||
|
sentences: List[str],
|
||
|
merge_sentences: bool=True,
|
||
|
with_erhua: bool=True) -> List[List[str]]:
|
||
|
segments = sentences
|
||
|
phones_list = []
|
||
|
for seg in segments:
|
||
|
phones = []
|
||
|
seg_cut = psg.lcut(seg)
|
||
|
initials = []
|
||
|
finals = []
|
||
|
seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut)
|
||
|
for word, pos in seg_cut:
|
||
|
if pos == 'eng':
|
||
|
continue
|
||
|
sub_initials, sub_finals = self._get_initials_finals(word)
|
||
|
sub_finals = self.tone_modifier.modified_tone(word, pos,
|
||
|
sub_finals)
|
||
|
if with_erhua:
|
||
|
sub_initials, sub_finals = self._merge_erhua(
|
||
|
sub_initials, sub_finals, word, pos)
|
||
|
initials.append(sub_initials)
|
||
|
finals.append(sub_finals)
|
||
|
# assert len(sub_initials) == len(sub_finals) == len(word)
|
||
|
initials = sum(initials, [])
|
||
|
finals = sum(finals, [])
|
||
|
|
||
|
for c, v in zip(initials, finals):
|
||
|
# NOTE: post process for pypinyin outputs
|
||
|
# we discriminate i, ii and iii
|
||
|
if c and c not in self.punc:
|
||
|
phones.append(c)
|
||
|
if v and v not in self.punc:
|
||
|
phones.append(v)
|
||
|
# add sp between sentence (replace the last punc with sp)
|
||
|
if initials[-1] in self.punc:
|
||
|
phones.append('sp')
|
||
|
phones_list.append(phones)
|
||
|
if merge_sentences:
|
||
|
merge_list = sum(phones_list, [])
|
||
|
phones_list = []
|
||
|
phones_list.append(merge_list)
|
||
|
return phones_list
|
||
|
|
||
|
def _merge_erhua(self,
|
||
|
initials: List[str],
|
||
|
finals: List[str],
|
||
|
word: str,
|
||
|
pos: str) -> List[List[str]]:
|
||
|
if word not in self.must_erhua and (word in self.not_erhua or
|
||
|
pos in {"a", "j", "nr"}):
|
||
|
return initials, finals
|
||
|
new_initials = []
|
||
|
new_finals = []
|
||
|
assert len(finals) == len(word)
|
||
|
for i, phn in enumerate(finals):
|
||
|
if i == len(finals) - 1 and word[i] == "儿" and phn in {
|
||
|
"er2", "er5"
|
||
|
} and word[-2:] not in self.not_erhua and new_finals:
|
||
|
new_finals[-1] = new_finals[-1][:-1] + "r" + new_finals[-1][-1]
|
||
|
else:
|
||
|
new_finals.append(phn)
|
||
|
new_initials.append(initials[i])
|
||
|
return new_initials, new_finals
|
||
|
|
||
|
def _p2id(self, phonemes: List[str]) -> np.array:
|
||
|
# replace unk phone with sp
|
||
|
phonemes = [
|
||
|
phn if phn in self.vocab_phones else "sp" for phn in phonemes
|
||
|
]
|
||
|
phone_ids = [self.vocab_phones[item] for item in phonemes]
|
||
|
return np.array(phone_ids, np.int64)
|
||
|
|
||
|
def _t2id(self, tones: List[str]) -> np.array:
|
||
|
# replace unk phone with sp
|
||
|
tones = [tone if tone in self.vocab_tones else "0" for tone in tones]
|
||
|
tone_ids = [self.vocab_tones[item] for item in tones]
|
||
|
return np.array(tone_ids, np.int64)
|
||
|
|
||
|
def _get_phone_tone(self, phonemes: List[str],
|
||
|
get_tone_ids: bool=False) -> List[List[str]]:
|
||
|
phones = []
|
||
|
tones = []
|
||
|
if get_tone_ids and self.vocab_tones:
|
||
|
for full_phone in phonemes:
|
||
|
# split tone from finals
|
||
|
match = re.match(r'^(\w+)([012345])$', full_phone)
|
||
|
if match:
|
||
|
phone = match.group(1)
|
||
|
tone = match.group(2)
|
||
|
# if the merged erhua not in the vocab
|
||
|
# assume that the input is ['iaor3'] and 'iaor' not in self.vocab_phones, we split 'iaor' into ['iao','er']
|
||
|
# and the tones accordingly change from ['3'] to ['3','2'], while '2' is the tone of 'er2'
|
||
|
if len(phone) >= 2 and phone != "er" and phone[
|
||
|
-1] == 'r' and phone not in self.vocab_phones and phone[:
|
||
|
-1] in self.vocab_phones:
|
||
|
phones.append(phone[:-1])
|
||
|
phones.append("er")
|
||
|
tones.append(tone)
|
||
|
tones.append("2")
|
||
|
else:
|
||
|
phones.append(phone)
|
||
|
tones.append(tone)
|
||
|
else:
|
||
|
phones.append(full_phone)
|
||
|
tones.append('0')
|
||
|
else:
|
||
|
for phone in phonemes:
|
||
|
# if the merged erhua not in the vocab
|
||
|
# assume that the input is ['iaor3'] and 'iaor' not in self.vocab_phones, change ['iaor3'] to ['iao3','er2']
|
||
|
if len(phone) >= 3 and phone[:-1] != "er" and phone[
|
||
|
-2] == 'r' and phone not in self.vocab_phones and (
|
||
|
phone[:-2] + phone[-1]) in self.vocab_phones:
|
||
|
phones.append((phone[:-2] + phone[-1]))
|
||
|
phones.append("er2")
|
||
|
else:
|
||
|
phones.append(phone)
|
||
|
return phones, tones
|
||
|
|
||
|
def get_phonemes(self,
|
||
|
sentence: str,
|
||
|
merge_sentences: bool=True,
|
||
|
with_erhua: bool=True) -> List[List[str]]:
|
||
|
sentences = self.text_normalizer.normalize(sentence)
|
||
|
phonemes = self._g2p(
|
||
|
sentences, merge_sentences=merge_sentences, with_erhua=with_erhua)
|
||
|
return phonemes
|
||
|
|
||
|
def get_input_ids(
|
||
|
self,
|
||
|
sentence: str,
|
||
|
merge_sentences: bool=True,
|
||
|
get_tone_ids: bool=False) -> Dict[str, List[paddle.Tensor]]:
|
||
|
phonemes = self.get_phonemes(sentence, merge_sentences=merge_sentences)
|
||
|
result = {}
|
||
|
phones = []
|
||
|
tones = []
|
||
|
temp_phone_ids = []
|
||
|
temp_tone_ids = []
|
||
|
for part_phonemes in phonemes:
|
||
|
phones, tones = self._get_phone_tone(
|
||
|
part_phonemes, get_tone_ids=get_tone_ids)
|
||
|
if tones:
|
||
|
tone_ids = self._t2id(tones)
|
||
|
tone_ids = paddle.to_tensor(tone_ids)
|
||
|
temp_tone_ids.append(tone_ids)
|
||
|
if phones:
|
||
|
phone_ids = self._p2id(phones)
|
||
|
phone_ids = paddle.to_tensor(phone_ids)
|
||
|
temp_phone_ids.append(phone_ids)
|
||
|
if temp_tone_ids:
|
||
|
result["tone_ids"] = temp_tone_ids
|
||
|
if temp_phone_ids:
|
||
|
result["phone_ids"] = temp_phone_ids
|
||
|
return result
|