diff --git a/paddlespeech/t2s/frontend/g2pw/dataset.py b/paddlespeech/t2s/frontend/g2pw/dataset.py
index 8125f71f..ab715dc3 100644
--- a/paddlespeech/t2s/frontend/g2pw/dataset.py
+++ b/paddlespeech/t2s/frontend/g2pw/dataset.py
@@ -1,65 +1,95 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Credits
-    This code is modified from https://github.com/GitYCC/g2pW
+    This code is modified from https://github.com/GitYCC/g2pW
 """
 import numpy as np
+
 from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map
 
 ANCHOR_CHAR = '▁'
 
-def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids, phonemes=None, pos_tags=None,
-                       use_mask=False, use_char_phoneme=False, use_pos=False, window_size=None, max_len=512):
-    if window_size is not None:
-        truncated_texts, truncated_query_ids = _truncate_texts(window_size, texts, query_ids)
-
-    input_ids = []
-    token_type_ids = []
-    attention_masks = []
-    phoneme_masks = []
-    char_ids = []
-    position_ids = []
-
-    for idx in range(len(texts)):
-        text = (truncated_texts if window_size else texts)[idx].lower()
-        query_id = (truncated_query_ids if window_size else query_ids)[idx]
-
-        try:
-            tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
-        except Exception:
-            print(f'warning: text "{text}" is invalid')
-            return {}
-
-        text, query_id, tokens, text2token, token2text = _truncate(max_len, text, query_id, tokens, text2token, token2text)
-
-        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
-
-        input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
-        token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
-        attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
-
-        query_char = text[query_id]
-        phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
-            if use_mask else [1] * len(labels)
-        char_id = chars.index(query_char)
-        position_id = text2token[query_id] + 1  # [CLS] token locate at first place
-
-        input_ids.append(input_id)
-        token_type_ids.append(token_type_id)
-        attention_masks.append(attention_mask)
-        phoneme_masks.append(phoneme_mask)
-        char_ids.append(char_id)
-        position_ids.append(position_id)
-
-    outputs = {
-        'input_ids': np.array(input_ids),
-        'token_type_ids': np.array(token_type_ids),
-        'attention_masks': np.array(attention_masks),
-        'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
-        'char_ids': np.array(char_ids),
-        'position_ids': np.array(position_ids),
-    }
-    return outputs
+def prepare_onnx_input(tokenizer,
+                       labels,
+                       char2phonemes,
+                       chars,
+                       texts,
+                       query_ids,
+                       phonemes=None,
+                       pos_tags=None,
+                       use_mask=False,
+                       use_char_phoneme=False,
+                       use_pos=False,
+                       window_size=None,
+                       max_len=512):
+    if window_size is not None:
+        truncated_texts, truncated_query_ids = _truncate_texts(window_size,
+                                                               texts, query_ids)
+
+    input_ids = []
+    token_type_ids = []
+    attention_masks = []
+    phoneme_masks = []
+    char_ids = []
+    position_ids = []
+
+    for idx in range(len(texts)):
+        text = (truncated_texts if window_size else texts)[idx].lower()
+        query_id = (truncated_query_ids if window_size else query_ids)[idx]
+
+        try:
+            tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
+        except Exception:
+            print(f'warning: text "{text}" is invalid')
+            return {}
+
+        text, query_id, tokens, text2token, token2text = _truncate(
+            max_len, text, query_id, tokens, text2token, token2text)
+
+        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
+
+        input_id = list(
+            np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
+        token_type_id = list(np.zeros((len(processed_tokens), ), dtype=int))
+        attention_mask = list(np.ones((len(processed_tokens), ), dtype=int))
+
+        query_char = text[query_id]
+        phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
+            if use_mask else [1] * len(labels)
+        char_id = chars.index(query_char)
+        position_id = text2token[
+            query_id] + 1  # [CLS] token locate at first place
+
+        input_ids.append(input_id)
+        token_type_ids.append(token_type_id)
+        attention_masks.append(attention_mask)
+        phoneme_masks.append(phoneme_mask)
+        char_ids.append(char_id)
+        position_ids.append(position_id)
+
+    outputs = {
+        'input_ids': np.array(input_ids),
+        'token_type_ids': np.array(token_type_ids),
+        'attention_masks': np.array(attention_masks),
+        'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
+        'char_ids': np.array(char_ids),
+        'position_ids': np.array(position_ids),
+    }
+    return outputs
+
 
 def _truncate_texts(window_size, texts, query_ids):
     truncated_texts = []
@@ -74,6 +104,7 @@ def _truncate_texts(window_size, texts, query_ids):
         truncated_texts.append(truncated_text)
         truncated_query_ids.append(truncated_query_id)
     return truncated_texts, truncated_query_ids
+
 def _truncate(max_len, text, query_id, tokens, text2token, token2text):
     truncate_len = max_len - 2
     if len(tokens) <= truncate_len:
@@ -95,13 +126,11 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text):
     start = token2text[token_start][0]
     end = token2text[token_end - 1][1]
 
-    return (
-        text[start:end],
-        query_id - start,
-        tokens[token_start:token_end],
-        [i - token_start if i is not None else None for i in text2token[start:end]],
-        [(s - start, e - start) for s, e in token2text[token_start:token_end]]
-    )
+    return (text[start:end], query_id - start, tokens[token_start:token_end], [
+        i - token_start if i is not None else None
+        for i in text2token[start:end]
+    ], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
+
 
 def prepare_data(sent_path, lb_path=None):
     raw_texts = open(sent_path).read().rstrip().split('\n')
@@ -125,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars):
 
 
 def get_char_phoneme_labels(polyphonic_chars):
-    labels = sorted(list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
+    labels = sorted(
+        list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
     char2phonemes = {}
     for char, phoneme in polyphonic_chars:
         if char not in char2phonemes:
diff --git a/paddlespeech/t2s/frontend/g2pw/onnx_api.py b/paddlespeech/t2s/frontend/g2pw/onnx_api.py
index ace943f2..3a406ad2 100644
--- a/paddlespeech/t2s/frontend/g2pw/onnx_api.py
+++ b/paddlespeech/t2s/frontend/g2pw/onnx_api.py
@@ -1,38 +1,54 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Credits
-    This code is modified from https://github.com/GitYCC/g2pW
+    This code is modified from https://github.com/GitYCC/g2pW
 """
-import os
 import json
-import onnxruntime
-import numpy as np
+import os
 
+import numpy as np
+import onnxruntime
 from opencc import OpenCC
-from pypinyin import pinyin, lazy_pinyin, Style
 from paddlenlp.transformers import BertTokenizer
-from paddlespeech.utils.env import MODEL_HOME
-from paddlespeech.t2s.frontend.g2pw.dataset import prepare_data,\
-                                                   prepare_onnx_input,\
-                                                   get_phoneme_labels,\
-                                                   get_char_phoneme_labels
-from paddlespeech.t2s.frontend.g2pw.utils import load_config
+from pypinyin import pinyin
+from pypinyin import Style
+
 from paddlespeech.cli.utils import download_and_decompress
 from paddlespeech.resource.pretrained_models import g2pw_onnx_models
+from paddlespeech.t2s.frontend.g2pw.dataset import get_char_phoneme_labels
+from paddlespeech.t2s.frontend.g2pw.dataset import get_phoneme_labels
+from paddlespeech.t2s.frontend.g2pw.dataset import prepare_onnx_input
+from paddlespeech.t2s.frontend.g2pw.utils import load_config
+from paddlespeech.utils.env import MODEL_HOME
 
 
 def predict(session, onnx_input, labels):
     all_preds = []
     all_confidences = []
-    probs = session.run([],{"input_ids": onnx_input['input_ids'],
-                            "token_type_ids":onnx_input['token_type_ids'],
-                            "attention_mask":onnx_input['attention_masks'],
-                            "phoneme_mask":onnx_input['phoneme_masks'],
-                            "char_ids":onnx_input['char_ids'],
-                            "position_ids":onnx_input['position_ids']})[0]
-
-    preds = np.argmax(probs,axis=1).tolist()
+    probs = session.run([], {
+        "input_ids": onnx_input['input_ids'],
+        "token_type_ids": onnx_input['token_type_ids'],
+        "attention_mask": onnx_input['attention_masks'],
+        "phoneme_mask": onnx_input['phoneme_masks'],
+        "char_ids": onnx_input['char_ids'],
+        "position_ids": onnx_input['position_ids']
+    })[0]
+
+    preds = np.argmax(probs, axis=1).tolist()
     max_probs = []
-    for index,arr in zip(preds,probs.tolist()):
+    for index, arr in zip(preds, probs.tolist()):
        max_probs.append(arr[index])
     all_preds += [labels[pred] for pred in preds]
     all_confidences += max_probs
@@ -41,39 +57,69 @@ def predict(session, onnx_input, labels):
 
 
 class G2PWOnnxConverter:
-    def __init__(self, model_dir = MODEL_HOME, style='bopomofo', model_source=None, enable_non_tradional_chinese=False):
+    def __init__(self,
+                 model_dir=MODEL_HOME,
+                 style='bopomofo',
+                 model_source=None,
+                 enable_non_tradional_chinese=False):
         if not os.path.exists(os.path.join(model_dir, 'G2PWModel/g2pW.onnx')):
-            uncompress_path = download_and_decompress(g2pw_onnx_models['G2PWModel']['1.0'],model_dir)
+            uncompress_path = download_and_decompress(
+                g2pw_onnx_models['G2PWModel']['1.0'], model_dir)
 
         sess_options = onnxruntime.SessionOptions()
         sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
         sess_options.intra_op_num_threads = 2
-        self.session_g2pW = onnxruntime.InferenceSession(os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),sess_options=sess_options)
-        self.config = load_config(os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)
+        self.session_g2pW = onnxruntime.InferenceSession(
+            os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),
+            sess_options=sess_options)
+        self.config = load_config(
+            os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)
 
         self.model_source = model_source if model_source else self.config.model_source
         self.enable_opencc = enable_non_tradional_chinese
 
         self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source)
 
-        polyphonic_chars_path = os.path.join(model_dir, 'G2PWModel/POLYPHONIC_CHARS.txt')
-        monophonic_chars_path = os.path.join(model_dir, 'G2PWModel/MONOPHONIC_CHARS.txt')
-        self.polyphonic_chars = [line.split('\t') for line in open(polyphonic_chars_path,encoding='utf-8').read().strip().split('\n')]
-        self.monophonic_chars = [line.split('\t') for line in open(monophonic_chars_path,encoding='utf-8').read().strip().split('\n')]
-        self.labels, self.char2phonemes = get_char_phoneme_labels(self.polyphonic_chars) if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars)
+        polyphonic_chars_path = os.path.join(model_dir,
+                                             'G2PWModel/POLYPHONIC_CHARS.txt')
+        monophonic_chars_path = os.path.join(model_dir,
+                                             'G2PWModel/MONOPHONIC_CHARS.txt')
+        self.polyphonic_chars = [
+            line.split('\t')
+            for line in open(polyphonic_chars_path, encoding='utf-8').read()
+            .strip().split('\n')
+        ]
+        self.monophonic_chars = [
+            line.split('\t')
+            for line in open(monophonic_chars_path, encoding='utf-8').read()
+            .strip().split('\n')
+        ]
+        self.labels, self.char2phonemes = get_char_phoneme_labels(
+            self.polyphonic_chars
+        ) if self.config.use_char_phoneme else get_phoneme_labels(
+            self.polyphonic_chars)
         self.chars = sorted(list(self.char2phonemes.keys()))
-        self.pos_tags = ['UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI']
-
-        with open(os.path.join(model_dir,'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'), 'r',encoding='utf-8') as fr:
+        self.pos_tags = [
+            'UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI'
+        ]
+
+        with open(
+                os.path.join(model_dir,
+                             'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'),
+                'r',
+                encoding='utf-8') as fr:
             self.bopomofo_convert_dict = json.load(fr)
         self.style_convert_func = {
             'bopomofo': lambda x: x,
             'pinyin': self._convert_bopomofo_to_pinyin,
         }[style]
-        with open(os.path.join(model_dir,'G2PWModel/char_bopomofo_dict.json'), 'r',encoding='utf-8') as fr:
+        with open(
+                os.path.join(model_dir, 'G2PWModel/char_bopomofo_dict.json'),
+                'r',
+                encoding='utf-8') as fr:
             self.char_bopomofo_dict = json.load(fr)
 
         if self.enable_opencc:
@@ -100,15 +146,23 @@ class G2PWOnnxConverter:
                 assert len(translated_sent) == len(sent)
                 translated_sentences.append(translated_sent)
             sentences = translated_sentences
-
-        texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences)
+
+        texts, query_ids, sent_ids, partial_results = self._prepare_data(
+            sentences)
         if len(texts) == 0:
             # sentences no polyphonic words
             return partial_results
 
-        onnx_input = prepare_onnx_input(self.tokenizer, self.labels, self.char2phonemes, self.chars, texts, query_ids,
-                                        use_mask=self.config.use_mask, use_char_phoneme=self.config.use_char_phoneme,
-                                        window_size=None)
+        onnx_input = prepare_onnx_input(
+            self.tokenizer,
+            self.labels,
+            self.char2phonemes,
+            self.chars,
+            texts,
+            query_ids,
+            use_mask=self.config.use_mask,
+            use_char_phoneme=self.config.use_char_phoneme,
+            window_size=None)
 
         preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
         if self.config.use_char_phoneme:
@@ -123,11 +177,12 @@ class G2PWOnnxConverter:
     def _prepare_data(self, sentences):
         polyphonic_chars = set(self.chars)
         monophonic_chars_dict = {
-            char: phoneme for char, phoneme in self.monophonic_chars
+            char: phoneme
+            for char, phoneme in self.monophonic_chars
         }
         texts, query_ids, sent_ids, partial_results = [], [], [], []
         for sent_id, sent in enumerate(sentences):
-            pypinyin_result = pinyin(sent,style=Style.TONE3)
+            pypinyin_result = pinyin(sent, style=Style.TONE3)
             partial_result = [None] * len(sent)
             for i, char in enumerate(sent):
                 if char in polyphonic_chars:
@@ -135,9 +190,10 @@ class G2PWOnnxConverter:
                     query_ids.append(i)
                     sent_ids.append(sent_id)
                 elif char in monophonic_chars_dict:
-                    partial_result[i] = self.style_convert_func(monophonic_chars_dict[char])
+                    partial_result[i] = self.style_convert_func(
+                        monophonic_chars_dict[char])
                 elif char in self.char_bopomofo_dict:
-                    partial_result[i] = pypinyin_result[i][0]
+                    partial_result[i] = pypinyin_result[i][0]
                     # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
             partial_results.append(partial_result)
         return texts, query_ids, sent_ids, partial_results
diff --git a/paddlespeech/t2s/frontend/g2pw/utils.py b/paddlespeech/t2s/frontend/g2pw/utils.py
index 771e9007..ad02c4c1 100644
--- a/paddlespeech/t2s/frontend/g2pw/utils.py
+++ b/paddlespeech/t2s/frontend/g2pw/utils.py
@@ -1,10 +1,22 @@
-
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Credits
-    This code is modified from https://github.com/GitYCC/g2pW
+    This code is modified from https://github.com/GitYCC/g2pW
 """
 import re
-import sys
+
 
 def wordize_and_map(text):
     words = []
@@ -92,7 +104,6 @@ default_config_dict = {
     'char-linear': True,
     'pos-linear': False,
     'char+pos-second': True,
-    'char+pos-second_lowrank': False,
     'lowrank_size': 0,
     'char+pos-second_fm': False,
@@ -130,4 +141,4 @@ def load_config(config_path, use_default=False):
                 for dict_k, dict_v in val.items():
                     if dict_k not in d:
                         d[dict_k] = dict_v
-    return config
\ No newline at end of file
+    return config
diff --git a/paddlespeech/t2s/frontend/zh_frontend.py b/paddlespeech/t2s/frontend/zh_frontend.py
index e6129999..9513a459 100644
--- a/paddlespeech/t2s/frontend/zh_frontend.py
+++ b/paddlespeech/t2s/frontend/zh_frontend.py
@@ -11,15 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import re
 import os
-import yaml
+import re
 from typing import Dict
 from typing import List
 
 import jieba.posseg as psg
 import numpy as np
 import paddle
+import yaml
 from g2pM import G2pM
 from pypinyin import lazy_pinyin
 from pypinyin import load_phrases_dict
@@ -58,19 +58,24 @@ def insert_after_character(lst, item):
 
 class Polyphonic():
     def __init__(self):
-        with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
-                  'polyphonic.yaml'), 'r',encoding='utf-8') as polyphonic_file:
+        with open(
+                os.path.join(
+                    os.path.dirname(os.path.abspath(__file__)),
+                    'polyphonic.yaml'),
+                'r',
+                encoding='utf-8') as polyphonic_file:
             # 解析yaml
             polyphonic_dict = yaml.load(polyphonic_file, Loader=yaml.FullLoader)
         self.polyphonic_words = polyphonic_dict["polyphonic"]
 
-    def correct_pronunciation(self,word,pinyin):
+    def correct_pronunciation(self, word, pinyin):
         # 词汇被词典收录则返回纠正后的读音
         if word in self.polyphonic_words.keys():
             pinyin = self.polyphonic_words[word]
         # 否则返回原读音
         return pinyin
 
+
 class Frontend():
     def __init__(self,
                  g2p_model="g2pW",
@@ -88,7 +93,8 @@ class Frontend():
         elif self.g2p_model == "g2pW":
             self.corrector = Polyphonic()
             self.g2pM_model = G2pM()
-            self.g2pW_model = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)
+            self.g2pW_model = G2PWOnnxConverter(
+                style='pinyin', enable_non_tradional_chinese=True)
             self.pinyin2phone = generate_lexicon(
                 with_tone=True, with_erhua=False)
 
@@ -187,7 +193,7 @@ class Frontend():
                     pinyins = self.g2pW_model(seg)[0]
                 except Exception:
                     # g2pW采用模型采用繁体输入,如果有cover不了的简体词,采用g2pM预测
-                    print("[%s] not in g2pW dict,use g2pM"%seg)
+                    print("[%s] not in g2pW dict,use g2pM" % seg)
                     pinyins = self.g2pM_model(seg, tone=True, char_split=False)
                 pre_word_length = 0
                 for word, pos in seg_cut:
@@ -199,13 +205,15 @@ class Frontend():
                         continue
                     word_pinyins = pinyins[pre_word_length:now_word_length]
                     # 矫正发音
-                    word_pinyins = self.corrector.correct_pronunciation(word,word_pinyins)
-                    for pinyin,char in zip(word_pinyins,word):
-                        if pinyin == None:
+                    word_pinyins = self.corrector.correct_pronunciation(
+                        word, word_pinyins)
+                    for pinyin, char in zip(word_pinyins, word):
+                        if pinyin is None:
                             pinyin = char
                         pinyin = pinyin.replace("u:", "v")
                         if pinyin in self.pinyin2phone:
-                            initial_final_list = self.pinyin2phone[pinyin].split(" ")
+                            initial_final_list = self.pinyin2phone[
+                                pinyin].split(" ")
                             if len(initial_final_list) == 2:
                                 sub_initials.append(initial_final_list[0])
                                 sub_finals.append(initial_final_list[1])
@@ -218,7 +226,7 @@ class Frontend():
                             sub_finals.append(pinyin)
                     pre_word_length = now_word_length
                     sub_finals = self.tone_modifier.modified_tone(word, pos,
-                                                                 sub_finals)
+                                                                  sub_finals)
                     if with_erhua:
                         sub_initials, sub_finals = self._merge_erhua(
                             sub_initials, sub_finals, word, pos)
@@ -231,7 +239,7 @@ class Frontend():
                     continue
                 sub_initials, sub_finals = self._get_initials_finals(word)
                 sub_finals = self.tone_modifier.modified_tone(word, pos,
-                                                              sub_finals)
+                                                              sub_finals)
                 if with_erhua:
                     sub_initials, sub_finals = self._merge_erhua(
                         sub_initials, sub_finals, word, pos)
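
Note: a minimal usage sketch of the converter this patch reformats, assuming the same entry point that zh_frontend.py uses above (the style and flag are copied from that call, and the per-sentence result list is inferred from the `self.g2pW_model(seg)[0]` usage; the example sentence is arbitrary):

    from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter

    # Downloads and unpacks G2PWModel into MODEL_HOME on first use (see __init__ above).
    converter = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)

    # Calling the converter on a sentence returns one pinyin list per input sentence;
    # zh_frontend.py takes the first (and only) list with [0].
    pinyins = converter('这是一个测试句子')[0]
    print(pinyins)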