g2pW onnxruntime

pull/2221/head
wangcanlong 3 years ago
parent 2498b9ce66
commit d451d1e099

@ -0,0 +1,2 @@
from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

@ -0,0 +1,135 @@
import numpy as np
from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map
ANCHOR_CHAR = ''
def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids, phonemes=None, pos_tags=None,
use_mask=False, use_char_phoneme=False, use_pos=False, window_size=None, max_len=512):
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(window_size, texts, query_ids)
input_ids = []
token_type_ids = []
attention_masks = []
phoneme_masks = []
char_ids = []
position_ids = []
for idx in range(len(texts)):
text = (truncated_texts if window_size else texts)[idx].lower()
query_id = (truncated_query_ids if window_size else query_ids)[idx]
try:
tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}
text, query_id, tokens, text2token, token2text = _truncate(max_len, text, query_id, tokens, text2token, token2text)
processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
query_char = text[query_id]
phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
if use_mask else [1] * len(labels)
char_id = chars.index(query_char)
position_id = text2token[query_id] + 1 # [CLS] token locate at first place
input_ids.append(input_id)
token_type_ids.append(token_type_id)
attention_masks.append(attention_mask)
phoneme_masks.append(phoneme_mask)
char_ids.append(char_id)
position_ids.append(position_id)
outputs = {
'input_ids': np.array(input_ids),
'token_type_ids': np.array(token_type_ids),
'attention_masks': np.array(attention_masks),
'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
'char_ids': np.array(char_ids),
'position_ids': np.array(position_ids),
}
return outputs
def _truncate_texts(window_size, texts, query_ids):
truncated_texts = []
truncated_query_ids = []
for text, query_id in zip(texts, query_ids):
start = max(0, query_id - window_size // 2)
end = min(len(text), query_id + window_size // 2)
truncated_text = text[start:end]
truncated_texts.append(truncated_text)
truncated_query_id = query_id - start
truncated_query_ids.append(truncated_query_id)
return truncated_texts, truncated_query_ids
def _truncate(max_len, text, query_id, tokens, text2token, token2text):
truncate_len = max_len - 2
if len(tokens) <= truncate_len:
return (text, query_id, tokens, text2token, token2text)
token_position = text2token[query_id]
token_start = token_position - truncate_len // 2
token_end = token_start + truncate_len
font_exceed_dist = -token_start
back_exceed_dist = token_end - len(tokens)
if font_exceed_dist > 0:
token_start += font_exceed_dist
token_end += font_exceed_dist
elif back_exceed_dist > 0:
token_start -= back_exceed_dist
token_end -= back_exceed_dist
start = token2text[token_start][0]
end = token2text[token_end - 1][1]
return (
text[start:end],
query_id - start,
tokens[token_start:token_end],
[i - token_start if i is not None else None for i in text2token[start:end]],
[(s - start, e - start) for s, e in token2text[token_start:token_end]]
)
def prepare_data(sent_path, lb_path=None):
raw_texts = open(sent_path).read().rstrip().split('\n')
query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts]
texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts]
if lb_path is None:
return texts, query_ids
else:
phonemes = open(lb_path).read().rstrip().split('\n')
return texts, query_ids, phonemes
def get_phoneme_labels(polyphonic_chars):
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
if char not in char2phonemes:
char2phonemes[char] = []
char2phonemes[char].append(labels.index(phoneme))
return labels, char2phonemes
def get_char_phoneme_labels(polyphonic_chars):
labels = sorted(list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
char2phonemes = {}
for char, phoneme in polyphonic_chars:
if char not in char2phonemes:
char2phonemes[char] = []
char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
return labels, char2phonemes
def prepare_pos(pos_path):
return open(pos_path).read().rstrip().split('\n')

@ -0,0 +1,146 @@
import os
import json
import requests
import zipfile
import onnxruntime
import numpy as np
from io import BytesIO
import shutil
from transformers import BertTokenizer
try:
from opencc import OpenCC
except:
pass
from paddlespeech.t2s.frontend.g2pw.dataset import prepare_data, prepare_onnx_input, get_phoneme_labels, get_char_phoneme_labels
from paddlespeech.t2s.frontend.g2pw.utils import load_config
MODEL_URL = 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/g2p/G2PWModel.tar'
def predict(session, onnx_input, labels):
all_preds = []
all_confidences = []
probs = session.run([],{"input_ids": onnx_input['input_ids'],
"token_type_ids":onnx_input['token_type_ids'],
"attention_mask":onnx_input['attention_masks'],
"phoneme_mask":onnx_input['phoneme_masks'],
"char_ids":onnx_input['char_ids'],
"position_ids":onnx_input['position_ids']})[0]
preds = np.argmax(probs,axis=1).tolist()
max_probs = []
for index,arr in zip(preds,probs.tolist()):
max_probs.append(arr[index])
all_preds += [labels[pred] for pred in preds]
all_confidences += max_probs
return all_preds, all_confidences
def download_model(model_dir):
wget_shell = "cd %s && wget %s"%(model_dir,MODEL_URL)
os.system(wget_shell)
shell = "cd %s ;tar -xvf %s;cd %s/G2PWModel;rm -rf .*" % (model_dir,MODEL_URL.split("/")[-1], model_dir)
os.system(shell)
rm_shell = "cd %s && rm -rf %s"%(model_dir,MODEL_URL.split("/")[-1])
os.system(rm_shell)
class G2PWOnnxConverter:
def __init__(self, style='bopomofo', model_source=None, enable_non_tradional_chinese=False):
model_dir = os.path.dirname(os.path.abspath(__file__))
if not os.path.exists(os.path.join(model_dir, 'G2PWModel/g2pW.onnx')):
download_model(model_dir)
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(model_dir, 'G2PWModel/g2pW.onnx'))
self.config = load_config(os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)
self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese
self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source)
polyphonic_chars_path = os.path.join(model_dir, 'G2PWModel/POLYPHONIC_CHARS.txt')
monophonic_chars_path = os.path.join(model_dir, 'G2PWModel/MONOPHONIC_CHARS.txt')
self.polyphonic_chars = [line.split('\t') for line in open(polyphonic_chars_path,encoding='utf-8').read().strip().split('\n')]
self.monophonic_chars = [line.split('\t') for line in open(monophonic_chars_path,encoding='utf-8').read().strip().split('\n')]
self.labels, self.char2phonemes = get_char_phoneme_labels(self.polyphonic_chars) if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars)
self.chars = sorted(list(self.char2phonemes.keys()))
self.pos_tags = ['UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI']
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
'bopomofo_to_pinyin_wo_tune_dict.json'), 'r',encoding='utf-8') as fr:
self.bopomofo_convert_dict = json.load(fr)
self.style_convert_func = {
'bopomofo': lambda x: x,
'pinyin': self._convert_bopomofo_to_pinyin,
}[style]
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
'char_bopomofo_dict.json'), 'r',encoding='utf-8') as fr:
self.char_bopomofo_dict = json.load(fr)
if self.enable_opencc:
self.cc = OpenCC('s2tw')
def _convert_bopomofo_to_pinyin(self, bopomofo):
tone = bopomofo[-1]
assert tone in '12345'
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
if component:
return component + tone
else:
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None
def __call__(self, sentences):
if isinstance(sentences, str):
sentences = [sentences]
if self.enable_opencc:
translated_sentences = []
for sent in sentences:
translated_sent = self.cc.convert(sent)
assert len(translated_sent) == len(sent)
translated_sentences.append(translated_sent)
sentences = translated_sentences
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences)
onnx_input = prepare_onnx_input(self.tokenizer, self.labels, self.char2phonemes, self.chars, texts, query_ids,
use_mask=self.config.use_mask, use_char_phoneme=self.config.use_char_phoneme,
window_size=self.config.window_size)
preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
if self.config.use_char_phoneme:
preds = [pred.split(' ')[1] for pred in preds]
results = partial_results
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
results[sent_id][query_id] = self.style_convert_func(pred)
return results
def _prepare_data(self, sentences):
polyphonic_chars = set(self.chars)
monophonic_chars_dict = {
char: phoneme for char, phoneme in self.monophonic_chars
}
texts, query_ids, sent_ids, partial_results = [], [], [], []
for sent_id, sent in enumerate(sentences):
partial_result = [None] * len(sent)
for i, char in enumerate(sent):
if char in polyphonic_chars:
texts.append(sent)
query_ids.append(i)
sent_ids.append(sent_id)
elif char in monophonic_chars_dict:
partial_result[i] = self.style_convert_func(monophonic_chars_dict[char])
elif char in self.char_bopomofo_dict:
partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
partial_results.append(partial_result)
return texts, query_ids, sent_ids, partial_results

@ -0,0 +1,161 @@
import re
import logging
import sys
class RunningAverage:
def __init__(self):
self.values = []
def add(self, val):
self.values.append(val)
def add_all(self, vals):
self.values += vals
def get(self):
if len(self.values) == 0:
return None
return sum(self.values) / len(self.values)
def flush(self):
self.values = []
def wordize_and_map(text):
words = []
index_map_from_text_to_word = []
index_map_from_word_to_text = []
while len(text) > 0:
match_space = re.match(r'^ +', text)
if match_space:
space_str = match_space.group(0)
index_map_from_text_to_word += [None] * len(space_str)
text = text[len(space_str):]
continue
match_en = re.match(r'^[a-zA-Z0-9]+', text)
if match_en:
en_word = match_en.group(0)
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + len(en_word)
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
index_map_from_text_to_word += [len(words)] * len(en_word)
words.append(en_word)
text = text[len(en_word):]
else:
word_start_pos = len(index_map_from_text_to_word)
word_end_pos = word_start_pos + 1
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
index_map_from_text_to_word += [len(words)]
words.append(text[0])
text = text[1:]
return words, index_map_from_text_to_word, index_map_from_word_to_text
def tokenize_and_map(tokenizer, text):
words, text2word, word2text = wordize_and_map(text)
tokens = []
index_map_from_token_to_text = []
for word, (word_start, word_end) in zip(words, word2text):
word_tokens = tokenizer.tokenize(word)
if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
index_map_from_token_to_text.append((word_start, word_end))
tokens.append('[UNK]')
else:
current_word_start = word_start
for word_token in word_tokens:
word_token_len = len(re.sub(r'^##', '', word_token))
index_map_from_token_to_text.append(
(current_word_start, current_word_start + word_token_len))
current_word_start = current_word_start + word_token_len
tokens.append(word_token)
index_map_from_text_to_token = text2word
for i, (token_start, token_end) in enumerate(index_map_from_token_to_text):
for token_pos in range(token_start, token_end):
index_map_from_text_to_token[token_pos] = i
return tokens, index_map_from_text_to_token, index_map_from_token_to_text
def _load_config(config_path):
import importlib.util
spec = importlib.util.spec_from_file_location('__init__', config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
return config
default_config_dict = {
'manual_seed': 1313,
'model_source': 'bert-base-chinese',
'window_size': 32,
'num_workers': 2,
'use_mask': True,
'use_char_phoneme': False,
'use_conditional': True,
'param_conditional': {
'affect_location': 'softmax',
'bias': True,
'char-linear': True,
'pos-linear': False,
'char+pos-second': True,
'char+pos-second_lowrank': False,
'lowrank_size': 0,
'char+pos-second_fm': False,
'fm_size': 0,
'fix_mode': None,
'count_json': 'train.count.json'
},
'lr': 5e-5,
'val_interval': 200,
'num_iter': 10000,
'use_focal': False,
'param_focal': {
'alpha': 0.0,
'gamma': 0.7
},
'use_pos': True,
'param_pos ': {
'weight': 0.1,
'pos_joint_training': True,
'train_pos_path': 'train.pos',
'valid_pos_path': 'dev.pos',
'test_pos_path': 'test.pos'
}
}
def load_config(config_path, use_default=False):
config = _load_config(config_path)
if use_default:
for attr, val in default_config_dict.items():
if not hasattr(config, attr):
setattr(config, attr, val)
elif isinstance(val, dict):
d = getattr(config, attr)
for dict_k, dict_v in val.items():
if dict_k not in d:
d[dict_k] = dict_v
return config
def get_logger(file_path):
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
output_file_handler = logging.FileHandler(file_path)
stdout_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(output_file_handler)
logger.addHandler(stdout_handler)
return logger

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re import re
import os
import yaml import yaml
from typing import Dict from typing import Dict
from typing import List from typing import List
@ -20,13 +21,13 @@ import jieba.posseg as psg
import numpy as np import numpy as np
import paddle import paddle
from g2pM import G2pM from g2pM import G2pM
from g2pw import G2PWConverter
from pypinyin import lazy_pinyin from pypinyin import lazy_pinyin
from pypinyin import load_phrases_dict from pypinyin import load_phrases_dict
from pypinyin import load_single_dict from pypinyin import load_single_dict
from pypinyin import Style from pypinyin import Style
from pypinyin_dict.phrase_pinyin_data import large_pinyin from pypinyin_dict.phrase_pinyin_data import large_pinyin
from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter
from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon
from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
@ -56,8 +57,9 @@ def insert_after_character(lst, item):
class Polyphonic(): class Polyphonic():
def __init__(self,dict_file="./paddlespeech/t2s/frontend/polyphonic.yaml"): def __init__(self):
with open(dict_file, encoding='utf8') as polyphonic_file: with open(os.path.join(os.path.dirname(os.path.abspath(__file__)),
'polyphonic.yaml'), 'r',encoding='utf-8') as polyphonic_file:
# 解析yaml # 解析yaml
polyphonic_dict = yaml.load(polyphonic_file, Loader=yaml.FullLoader) polyphonic_dict = yaml.load(polyphonic_file, Loader=yaml.FullLoader)
self.polyphonic_words = polyphonic_dict["polyphonic"] self.polyphonic_words = polyphonic_dict["polyphonic"]
@ -85,7 +87,7 @@ class Frontend():
with_tone=True, with_erhua=False) with_tone=True, with_erhua=False)
elif self.g2p_model == "g2pW": elif self.g2p_model == "g2pW":
self.corrector = Polyphonic() self.corrector = Polyphonic()
self.g2pW_model = G2PWConverter(style='pinyin', enable_non_tradional_chinese=True) self.g2pW_model = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)
self.pinyin2phone = generate_lexicon( self.pinyin2phone = generate_lexicon(
with_tone=True, with_erhua=False) with_tone=True, with_erhua=False)
@ -161,24 +163,6 @@ class Frontend():
# If it's not pinyin (possibly punctuation) or no conversion is required # If it's not pinyin (possibly punctuation) or no conversion is required
initials.append(pinyin) initials.append(pinyin)
finals.append(pinyin) finals.append(pinyin)
elif self.g2p_model == "g2pW":
pinyins = self.g2pW_model(word)[0]
if pinyins == [None]:
pinyins = [word]
for pinyin in pinyins:
pinyin = pinyin.replace("u:", "v")
if pinyin in self.pinyin2phone:
initial_final_list = self.pinyin2phone[pinyin].split(" ")
if len(initial_final_list) == 2:
initials.append(initial_final_list[0])
finals.append(initial_final_list[1])
elif len(initial_final_list) == 1:
initials.append('')
finals.append(initial_final_list[1])
else:
# If it's not pinyin (possibly punctuation) or no conversion is required
initials.append(pinyin)
finals.append(pinyin)
return initials, finals return initials, finals
# if merge_sentences, merge all sentences into one phone sequence # if merge_sentences, merge all sentences into one phone sequence

Loading…
Cancel
Save