commit a75b2a5bab

@@ -0,0 +1,2 @@
include paddlespeech/t2s/exps/*.txt
include paddlespeech/t2s/frontend/*.yaml
@@ -0,0 +1,2 @@
from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter
@@ -0,0 +1,134 @@
"""
Credits
    This code is modified from https://github.com/GitYCC/g2pW
"""
import numpy as np

from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map

ANCHOR_CHAR = '▁'


def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids,
                       phonemes=None, pos_tags=None, use_mask=False,
                       use_char_phoneme=False, use_pos=False, window_size=None,
                       max_len=512):
    if window_size is not None:
        truncated_texts, truncated_query_ids = _truncate_texts(window_size, texts, query_ids)

    input_ids = []
    token_type_ids = []
    attention_masks = []
    phoneme_masks = []
    char_ids = []
    position_ids = []

    for idx in range(len(texts)):
        text = (truncated_texts if window_size else texts)[idx].lower()
        query_id = (truncated_query_ids if window_size else query_ids)[idx]

        try:
            tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
        except Exception:
            print(f'warning: text "{text}" is invalid')
            return {}

        text, query_id, tokens, text2token, token2text = _truncate(
            max_len, text, query_id, tokens, text2token, token2text)

        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']

        input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
        token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
        attention_mask = list(np.ones((len(processed_tokens),), dtype=int))

        query_char = text[query_id]
        phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
            if use_mask else [1] * len(labels)
        char_id = chars.index(query_char)
        position_id = text2token[query_id] + 1  # +1 because the [CLS] token occupies the first position

        input_ids.append(input_id)
        token_type_ids.append(token_type_id)
        attention_masks.append(attention_mask)
        phoneme_masks.append(phoneme_mask)
        char_ids.append(char_id)
        position_ids.append(position_id)

    outputs = {
        'input_ids': np.array(input_ids),
        'token_type_ids': np.array(token_type_ids),
        'attention_masks': np.array(attention_masks),
        'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
        'char_ids': np.array(char_ids),
        'position_ids': np.array(position_ids),
    }
    return outputs


def _truncate_texts(window_size, texts, query_ids):
    truncated_texts = []
    truncated_query_ids = []
    for text, query_id in zip(texts, query_ids):
        start = max(0, query_id - window_size // 2)
        end = min(len(text), query_id + window_size // 2)
        truncated_text = text[start:end]
        truncated_texts.append(truncated_text)

        truncated_query_id = query_id - start
        truncated_query_ids.append(truncated_query_id)
    return truncated_texts, truncated_query_ids


def _truncate(max_len, text, query_id, tokens, text2token, token2text):
    truncate_len = max_len - 2
    if len(tokens) <= truncate_len:
        return (text, query_id, tokens, text2token, token2text)

    token_position = text2token[query_id]

    token_start = token_position - truncate_len // 2
    token_end = token_start + truncate_len
    font_exceed_dist = -token_start
    back_exceed_dist = token_end - len(tokens)
    if font_exceed_dist > 0:
        token_start += font_exceed_dist
        token_end += font_exceed_dist
    elif back_exceed_dist > 0:
        token_start -= back_exceed_dist
        token_end -= back_exceed_dist

    start = token2text[token_start][0]
    end = token2text[token_end - 1][1]

    return (
        text[start:end],
        query_id - start,
        tokens[token_start:token_end],
        [i - token_start if i is not None else None for i in text2token[start:end]],
        [(s - start, e - start) for s, e in token2text[token_start:token_end]]
    )


def prepare_data(sent_path, lb_path=None):
    raw_texts = open(sent_path).read().rstrip().split('\n')
    query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts]
    texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts]
    if lb_path is None:
        return texts, query_ids
    else:
        phonemes = open(lb_path).read().rstrip().split('\n')
        return texts, query_ids, phonemes


def get_phoneme_labels(polyphonic_chars):
    labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
    char2phonemes = {}
    for char, phoneme in polyphonic_chars:
        if char not in char2phonemes:
            char2phonemes[char] = []
        char2phonemes[char].append(labels.index(phoneme))
    return labels, char2phonemes


def get_char_phoneme_labels(polyphonic_chars):
    labels = sorted(list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
    char2phonemes = {}
    for char, phoneme in polyphonic_chars:
        if char not in char2phonemes:
            char2phonemes[char] = []
        char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
    return labels, char2phonemes
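
Note: the helpers above are easiest to read from a concrete example. prepare_data() expects one sentence per line with the polyphonic character marked by ANCHOR_CHAR ('▁') placed immediately before it (one anchor per line, as in the upstream g2pW data format), and get_phoneme_labels() turns (char, phoneme) pairs into a sorted label list plus a char-to-label-index map. The values below are made up purely for illustration and are not part of this commit:

    # Illustrative sketch; the characters and phonemes are invented.
    from paddlespeech.t2s.frontend.g2pw.dataset import ANCHOR_CHAR, get_phoneme_labels

    polyphonic_chars = [('行', 'xing2'), ('行', 'hang2'), ('长', 'chang2'), ('长', 'zhang3')]
    labels, char2phonemes = get_phoneme_labels(polyphonic_chars)
    # labels        -> ['chang2', 'hang2', 'xing2', 'zhang3']   (sorted, deduplicated)
    # char2phonemes -> {'行': [2, 1], '长': [0, 3]}              (indices into labels)

    # A line in the sentence file as prepare_data() reads it: the anchor's index
    # becomes the target character's index once the anchor is stripped.
    raw = '银' + ANCHOR_CHAR + '行'
    query_id = raw.index(ANCHOR_CHAR)      # 1
    text = raw.replace(ANCHOR_CHAR, '')    # '银行', target character at index 1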
@@ -0,0 +1,143 @@
"""
Credits
    This code is modified from https://github.com/GitYCC/g2pW
"""
import os
import json

import numpy as np
import onnxruntime
from opencc import OpenCC
from pypinyin import pinyin, lazy_pinyin, Style
from paddlenlp.transformers import BertTokenizer

from paddlespeech.utils.env import MODEL_HOME
from paddlespeech.t2s.frontend.g2pw.dataset import (prepare_data,
                                                    prepare_onnx_input,
                                                    get_phoneme_labels,
                                                    get_char_phoneme_labels)
from paddlespeech.t2s.frontend.g2pw.utils import load_config
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.resource.pretrained_models import g2pw_onnx_models


def predict(session, onnx_input, labels):
    all_preds = []
    all_confidences = []
    probs = session.run([], {
        "input_ids": onnx_input['input_ids'],
        "token_type_ids": onnx_input['token_type_ids'],
        "attention_mask": onnx_input['attention_masks'],
        "phoneme_mask": onnx_input['phoneme_masks'],
        "char_ids": onnx_input['char_ids'],
        "position_ids": onnx_input['position_ids']
    })[0]

    preds = np.argmax(probs, axis=1).tolist()
    max_probs = []
    for index, arr in zip(preds, probs.tolist()):
        max_probs.append(arr[index])
    all_preds += [labels[pred] for pred in preds]
    all_confidences += max_probs

    return all_preds, all_confidences


class G2PWOnnxConverter:
    def __init__(self, model_dir=MODEL_HOME, style='bopomofo', model_source=None,
                 enable_non_tradional_chinese=False):
        if not os.path.exists(os.path.join(model_dir, 'G2PWModel/g2pW.onnx')):
            uncompress_path = download_and_decompress(g2pw_onnx_models['G2PWModel']['1.0'], model_dir)

        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
        sess_options.intra_op_num_threads = 2
        self.session_g2pW = onnxruntime.InferenceSession(
            os.path.join(model_dir, 'G2PWModel/g2pW.onnx'), sess_options=sess_options)
        self.config = load_config(os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)

        self.model_source = model_source if model_source else self.config.model_source
        self.enable_opencc = enable_non_tradional_chinese

        self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source)

        polyphonic_chars_path = os.path.join(model_dir, 'G2PWModel/POLYPHONIC_CHARS.txt')
        monophonic_chars_path = os.path.join(model_dir, 'G2PWModel/MONOPHONIC_CHARS.txt')
        self.polyphonic_chars = [
            line.split('\t')
            for line in open(polyphonic_chars_path, encoding='utf-8').read().strip().split('\n')
        ]
        self.monophonic_chars = [
            line.split('\t')
            for line in open(monophonic_chars_path, encoding='utf-8').read().strip().split('\n')
        ]
        self.labels, self.char2phonemes = get_char_phoneme_labels(self.polyphonic_chars) \
            if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars)

        self.chars = sorted(list(self.char2phonemes.keys()))
        self.pos_tags = ['UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI']

        with open(os.path.join(model_dir, 'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'), 'r', encoding='utf-8') as fr:
            self.bopomofo_convert_dict = json.load(fr)
        self.style_convert_func = {
            'bopomofo': lambda x: x,
            'pinyin': self._convert_bopomofo_to_pinyin,
        }[style]

        with open(os.path.join(model_dir, 'G2PWModel/char_bopomofo_dict.json'), 'r', encoding='utf-8') as fr:
            self.char_bopomofo_dict = json.load(fr)

        if self.enable_opencc:
            self.cc = OpenCC('s2tw')

    def _convert_bopomofo_to_pinyin(self, bopomofo):
        tone = bopomofo[-1]
        assert tone in '12345'
        component = self.bopomofo_convert_dict.get(bopomofo[:-1])
        if component:
            return component + tone
        else:
            print(f'Warning: "{bopomofo}" cannot be converted to pinyin')
            return None

    def __call__(self, sentences):
        if isinstance(sentences, str):
            sentences = [sentences]

        if self.enable_opencc:
            translated_sentences = []
            for sent in sentences:
                translated_sent = self.cc.convert(sent)
                assert len(translated_sent) == len(sent)
                translated_sentences.append(translated_sent)
            sentences = translated_sentences

        texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences)
        if len(texts) == 0:
            # no polyphonic characters in any sentence, so nothing to predict
            return partial_results

        onnx_input = prepare_onnx_input(self.tokenizer, self.labels, self.char2phonemes,
                                        self.chars, texts, query_ids,
                                        use_mask=self.config.use_mask,
                                        use_char_phoneme=self.config.use_char_phoneme,
                                        window_size=None)

        preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
        if self.config.use_char_phoneme:
            preds = [pred.split(' ')[1] for pred in preds]

        results = partial_results
        for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
            results[sent_id][query_id] = self.style_convert_func(pred)

        return results

    def _prepare_data(self, sentences):
        polyphonic_chars = set(self.chars)
        monophonic_chars_dict = {
            char: phoneme for char, phoneme in self.monophonic_chars
        }
        texts, query_ids, sent_ids, partial_results = [], [], [], []
        for sent_id, sent in enumerate(sentences):
            pypinyin_result = pinyin(sent, style=Style.TONE3)
            partial_result = [None] * len(sent)
            for i, char in enumerate(sent):
                if char in polyphonic_chars:
                    texts.append(sent)
                    query_ids.append(i)
                    sent_ids.append(sent_id)
                elif char in monophonic_chars_dict:
                    partial_result[i] = self.style_convert_func(monophonic_chars_dict[char])
                elif char in self.char_bopomofo_dict:
                    partial_result[i] = pypinyin_result[i][0]
                    # partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
            partial_results.append(partial_result)
        return texts, query_ids, sent_ids, partial_results
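
Note: a hedged usage sketch of the new converter (the example sentence and the commented output are illustrative, not taken from this commit). The class is re-exported from paddlespeech.t2s.frontend.g2pw, accepts a single string or a list of sentences, and returns one list per sentence with an entry per character: a pronunciation for covered Chinese characters, None for everything else.

    # Illustrative usage; the first call downloads G2PWModel into MODEL_HOME.
    from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter

    g2pw = G2PWOnnxConverter(style='pinyin', enable_non_tradional_chinese=True)
    results = g2pw('银行很行')
    # results[0] has one element per character, e.g. tone-numbered pinyin
    # strings such as 'hang2' or 'xing2', or None for uncovered characters.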
@@ -0,0 +1,133 @@
"""
Credits
    This code is modified from https://github.com/GitYCC/g2pW
"""
import re
import sys


def wordize_and_map(text):
    words = []
    index_map_from_text_to_word = []
    index_map_from_word_to_text = []
    while len(text) > 0:
        match_space = re.match(r'^ +', text)
        if match_space:
            space_str = match_space.group(0)
            index_map_from_text_to_word += [None] * len(space_str)
            text = text[len(space_str):]
            continue

        match_en = re.match(r'^[a-zA-Z0-9]+', text)
        if match_en:
            en_word = match_en.group(0)

            word_start_pos = len(index_map_from_text_to_word)
            word_end_pos = word_start_pos + len(en_word)
            index_map_from_word_to_text.append((word_start_pos, word_end_pos))

            index_map_from_text_to_word += [len(words)] * len(en_word)

            words.append(en_word)
            text = text[len(en_word):]
        else:
            word_start_pos = len(index_map_from_text_to_word)
            word_end_pos = word_start_pos + 1
            index_map_from_word_to_text.append((word_start_pos, word_end_pos))

            index_map_from_text_to_word += [len(words)]

            words.append(text[0])
            text = text[1:]
    return words, index_map_from_text_to_word, index_map_from_word_to_text


def tokenize_and_map(tokenizer, text):
    words, text2word, word2text = wordize_and_map(text)

    tokens = []
    index_map_from_token_to_text = []
    for word, (word_start, word_end) in zip(words, word2text):
        word_tokens = tokenizer.tokenize(word)

        if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
            index_map_from_token_to_text.append((word_start, word_end))
            tokens.append('[UNK]')
        else:
            current_word_start = word_start
            for word_token in word_tokens:
                word_token_len = len(re.sub(r'^##', '', word_token))
                index_map_from_token_to_text.append(
                    (current_word_start, current_word_start + word_token_len))
                current_word_start = current_word_start + word_token_len
                tokens.append(word_token)

    index_map_from_text_to_token = text2word
    for i, (token_start, token_end) in enumerate(index_map_from_token_to_text):
        for token_pos in range(token_start, token_end):
            index_map_from_text_to_token[token_pos] = i

    return tokens, index_map_from_text_to_token, index_map_from_token_to_text


def _load_config(config_path):
    import importlib.util
    spec = importlib.util.spec_from_file_location('__init__', config_path)
    config = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config)
    return config


default_config_dict = {
    'manual_seed': 1313,
    'model_source': 'bert-base-chinese',
    'window_size': 32,
    'num_workers': 2,
    'use_mask': True,
    'use_char_phoneme': False,
    'use_conditional': True,
    'param_conditional': {
        'affect_location': 'softmax',
        'bias': True,
        'char-linear': True,
        'pos-linear': False,
        'char+pos-second': True,
        'char+pos-second_lowrank': False,
        'lowrank_size': 0,
        'char+pos-second_fm': False,
        'fm_size': 0,
        'fix_mode': None,
        'count_json': 'train.count.json'
    },
    'lr': 5e-5,
    'val_interval': 200,
    'num_iter': 10000,
    'use_focal': False,
    'param_focal': {
        'alpha': 0.0,
        'gamma': 0.7
    },
    'use_pos': True,
    'param_pos': {
        'weight': 0.1,
        'pos_joint_training': True,
        'train_pos_path': 'train.pos',
        'valid_pos_path': 'dev.pos',
        'test_pos_path': 'test.pos'
    }
}


def load_config(config_path, use_default=False):
    config = _load_config(config_path)
    if use_default:
        for attr, val in default_config_dict.items():
            if not hasattr(config, attr):
                setattr(config, attr, val)
            elif isinstance(val, dict):
                d = getattr(config, attr)
                for dict_k, dict_v in val.items():
                    if dict_k not in d:
                        d[dict_k] = dict_v
    return config
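
Note: the two index maps built above drive the query-position bookkeeping in dataset.py, so here is a deterministic sketch of wordize_and_map(); tokenize_and_map() then refines each word span into BERT sub-token spans via tokenizer.tokenize, measuring sub-token lengths with the '##' prefix stripped. The input string is invented for illustration:

    from paddlespeech.t2s.frontend.g2pw.utils import wordize_and_map

    words, text2word, word2text = wordize_and_map('TTS 前端')
    # words     -> ['TTS', '前', '端']        runs of [a-zA-Z0-9] stay one word;
    #                                         every other non-space char is its own word
    # text2word -> [0, 0, 0, None, 1, 2]      per text position: word index, None for spaces
    # word2text -> [(0, 3), (4, 5), (5, 6)]   per word: (start, end) span in the text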
@@ -0,0 +1,26 @@
polyphonic:
  湖泊: ['hu2', 'po1']
  地壳: ['di4', 'qiao4']
  柏树: ['bai3', 'shu4']
  曝光: ['bao4', 'guang1']
  弹力: ['tan2', 'li4']
  字帖: ['zi4', 'tie4']
  口吃: ['kou3', 'chi1']
  包扎: ['bao1', 'za1']
  哪吒: ['ne2', 'zha1']
  说服: ['shuo1', 'fu2']
  识字: ['shi2', 'zi4']
  骨头: ['gu3', 'tou5']
  对称: ['dui4', 'chen4']
  口供: ['kou3', 'gong4']
  抹布: ['ma1', 'bu4']
  露背: ['lu4', 'bei4']
  圈养: ['juan4', 'yang3']
  眼眶: ['yan3', 'kuang4']
  品行: ['pin3', 'xing2']
  颤抖: ['chan4', 'dou3']
  差不多: ['cha4', 'bu5', 'duo1']
  鸭绿江: ['ya1', 'lu4', 'jiang1']
  撒切尔: ['sa4', 'qie4', 'er3']
  比比皆是: ['bi3', 'bi3', 'jie1', 'shi4']
  身无长物: ['shen1', 'wu2', 'chang2', 'wu4']
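
Note: this YAML table maps specific polyphonic words to fixed tone-numbered pinyin, one entry per character of the word. A minimal loading sketch follows; the file path and name are assumptions for illustration, since the commit view does not show the file name (only that frontend *.yaml files are now packaged via the MANIFEST include above):

    import yaml

    # Assumed location; adjust to wherever the packaged YAML actually lives.
    with open('paddlespeech/t2s/frontend/polyphonic.yaml', encoding='utf-8') as f:
        polyphonic_dict = yaml.safe_load(f)['polyphonic']
    print(polyphonic_dict['湖泊'])   # ['hu2', 'po1']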