commit
a75b2a5bab
@ -0,0 +1,2 @@
|
||||
include paddlespeech/t2s/exps/*.txt
|
||||
include paddlespeech/t2s/frontend/*.yaml
|
@ -0,0 +1,2 @@
|
||||
from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter
|
||||
|
@ -0,0 +1,134 @@
|
||||
"""
|
||||
Credits
|
||||
This code is modified from https://github.com/GitYCC/g2pW
|
||||
"""
|
||||
import numpy as np
|
||||
from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map
|
||||
|
||||
ANCHOR_CHAR = '▁'
|
||||
|
||||
|
||||
def prepare_onnx_input(tokenizer, labels, char2phonemes, chars, texts, query_ids, phonemes=None, pos_tags=None,
|
||||
use_mask=False, use_char_phoneme=False, use_pos=False, window_size=None, max_len=512):
|
||||
if window_size is not None:
|
||||
truncated_texts, truncated_query_ids = _truncate_texts(window_size, texts, query_ids)
|
||||
|
||||
input_ids = []
|
||||
token_type_ids = []
|
||||
attention_masks = []
|
||||
phoneme_masks = []
|
||||
char_ids = []
|
||||
position_ids = []
|
||||
|
||||
for idx in range(len(texts)):
|
||||
text = (truncated_texts if window_size else texts)[idx].lower()
|
||||
query_id = (truncated_query_ids if window_size else query_ids)[idx]
|
||||
|
||||
try:
|
||||
tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
|
||||
except Exception:
|
||||
print(f'warning: text "{text}" is invalid')
|
||||
return {}
|
||||
|
||||
text, query_id, tokens, text2token, token2text = _truncate(max_len, text, query_id, tokens, text2token, token2text)
|
||||
|
||||
processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
|
||||
|
||||
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
|
||||
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
|
||||
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
|
||||
|
||||
query_char = text[query_id]
|
||||
phoneme_mask = [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] \
|
||||
if use_mask else [1] * len(labels)
|
||||
char_id = chars.index(query_char)
|
||||
position_id = text2token[query_id] + 1 # [CLS] token locate at first place
|
||||
|
||||
input_ids.append(input_id)
|
||||
token_type_ids.append(token_type_id)
|
||||
attention_masks.append(attention_mask)
|
||||
phoneme_masks.append(phoneme_mask)
|
||||
char_ids.append(char_id)
|
||||
position_ids.append(position_id)
|
||||
|
||||
outputs = {
|
||||
'input_ids': np.array(input_ids),
|
||||
'token_type_ids': np.array(token_type_ids),
|
||||
'attention_masks': np.array(attention_masks),
|
||||
'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
|
||||
'char_ids': np.array(char_ids),
|
||||
'position_ids': np.array(position_ids),
|
||||
}
|
||||
return outputs
|
||||
|
||||
def _truncate_texts(window_size, texts, query_ids):
|
||||
truncated_texts = []
|
||||
truncated_query_ids = []
|
||||
for text, query_id in zip(texts, query_ids):
|
||||
start = max(0, query_id - window_size // 2)
|
||||
end = min(len(text), query_id + window_size // 2)
|
||||
truncated_text = text[start:end]
|
||||
truncated_texts.append(truncated_text)
|
||||
|
||||
truncated_query_id = query_id - start
|
||||
truncated_query_ids.append(truncated_query_id)
|
||||
return truncated_texts, truncated_query_ids
|
||||
|
||||
def _truncate(max_len, text, query_id, tokens, text2token, token2text):
|
||||
truncate_len = max_len - 2
|
||||
if len(tokens) <= truncate_len:
|
||||
return (text, query_id, tokens, text2token, token2text)
|
||||
|
||||
token_position = text2token[query_id]
|
||||
|
||||
token_start = token_position - truncate_len // 2
|
||||
token_end = token_start + truncate_len
|
||||
font_exceed_dist = -token_start
|
||||
back_exceed_dist = token_end - len(tokens)
|
||||
if font_exceed_dist > 0:
|
||||
token_start += font_exceed_dist
|
||||
token_end += font_exceed_dist
|
||||
elif back_exceed_dist > 0:
|
||||
token_start -= back_exceed_dist
|
||||
token_end -= back_exceed_dist
|
||||
|
||||
start = token2text[token_start][0]
|
||||
end = token2text[token_end - 1][1]
|
||||
|
||||
return (
|
||||
text[start:end],
|
||||
query_id - start,
|
||||
tokens[token_start:token_end],
|
||||
[i - token_start if i is not None else None for i in text2token[start:end]],
|
||||
[(s - start, e - start) for s, e in token2text[token_start:token_end]]
|
||||
)
|
||||
|
||||
def prepare_data(sent_path, lb_path=None):
|
||||
raw_texts = open(sent_path).read().rstrip().split('\n')
|
||||
query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts]
|
||||
texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts]
|
||||
if lb_path is None:
|
||||
return texts, query_ids
|
||||
else:
|
||||
phonemes = open(lb_path).read().rstrip().split('\n')
|
||||
return texts, query_ids, phonemes
|
||||
|
||||
|
||||
def get_phoneme_labels(polyphonic_chars):
|
||||
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
|
||||
char2phonemes = {}
|
||||
for char, phoneme in polyphonic_chars:
|
||||
if char not in char2phonemes:
|
||||
char2phonemes[char] = []
|
||||
char2phonemes[char].append(labels.index(phoneme))
|
||||
return labels, char2phonemes
|
||||
|
||||
|
||||
def get_char_phoneme_labels(polyphonic_chars):
|
||||
labels = sorted(list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
|
||||
char2phonemes = {}
|
||||
for char, phoneme in polyphonic_chars:
|
||||
if char not in char2phonemes:
|
||||
char2phonemes[char] = []
|
||||
char2phonemes[char].append(labels.index(f'{char} {phoneme}'))
|
||||
return labels, char2phonemes
|
@ -0,0 +1,143 @@
|
||||
"""
|
||||
Credits
|
||||
This code is modified from https://github.com/GitYCC/g2pW
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import onnxruntime
|
||||
import numpy as np
|
||||
|
||||
from opencc import OpenCC
|
||||
from pypinyin import pinyin, lazy_pinyin, Style
|
||||
from paddlenlp.transformers import BertTokenizer
|
||||
from paddlespeech.utils.env import MODEL_HOME
|
||||
from paddlespeech.t2s.frontend.g2pw.dataset import prepare_data,\
|
||||
prepare_onnx_input,\
|
||||
get_phoneme_labels,\
|
||||
get_char_phoneme_labels
|
||||
from paddlespeech.t2s.frontend.g2pw.utils import load_config
|
||||
from paddlespeech.cli.utils import download_and_decompress
|
||||
from paddlespeech.resource.pretrained_models import g2pw_onnx_models
|
||||
|
||||
|
||||
def predict(session, onnx_input, labels):
|
||||
all_preds = []
|
||||
all_confidences = []
|
||||
probs = session.run([],{"input_ids": onnx_input['input_ids'],
|
||||
"token_type_ids":onnx_input['token_type_ids'],
|
||||
"attention_mask":onnx_input['attention_masks'],
|
||||
"phoneme_mask":onnx_input['phoneme_masks'],
|
||||
"char_ids":onnx_input['char_ids'],
|
||||
"position_ids":onnx_input['position_ids']})[0]
|
||||
|
||||
preds = np.argmax(probs,axis=1).tolist()
|
||||
max_probs = []
|
||||
for index,arr in zip(preds,probs.tolist()):
|
||||
max_probs.append(arr[index])
|
||||
all_preds += [labels[pred] for pred in preds]
|
||||
all_confidences += max_probs
|
||||
|
||||
return all_preds, all_confidences
|
||||
|
||||
|
||||
class G2PWOnnxConverter:
|
||||
def __init__(self, model_dir = MODEL_HOME, style='bopomofo', model_source=None, enable_non_tradional_chinese=False):
|
||||
if not os.path.exists(os.path.join(model_dir, 'G2PWModel/g2pW.onnx')):
|
||||
uncompress_path = download_and_decompress(g2pw_onnx_models['G2PWModel']['1.0'],model_dir)
|
||||
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
||||
sess_options.intra_op_num_threads = 2
|
||||
self.session_g2pW = onnxruntime.InferenceSession(os.path.join(model_dir, 'G2PWModel/g2pW.onnx'),sess_options=sess_options)
|
||||
self.config = load_config(os.path.join(model_dir, 'G2PWModel/config.py'), use_default=True)
|
||||
|
||||
self.model_source = model_source if model_source else self.config.model_source
|
||||
self.enable_opencc = enable_non_tradional_chinese
|
||||
|
||||
self.tokenizer = BertTokenizer.from_pretrained(self.config.model_source)
|
||||
|
||||
polyphonic_chars_path = os.path.join(model_dir, 'G2PWModel/POLYPHONIC_CHARS.txt')
|
||||
monophonic_chars_path = os.path.join(model_dir, 'G2PWModel/MONOPHONIC_CHARS.txt')
|
||||
self.polyphonic_chars = [line.split('\t') for line in open(polyphonic_chars_path,encoding='utf-8').read().strip().split('\n')]
|
||||
self.monophonic_chars = [line.split('\t') for line in open(monophonic_chars_path,encoding='utf-8').read().strip().split('\n')]
|
||||
self.labels, self.char2phonemes = get_char_phoneme_labels(self.polyphonic_chars) if self.config.use_char_phoneme else get_phoneme_labels(self.polyphonic_chars)
|
||||
|
||||
self.chars = sorted(list(self.char2phonemes.keys()))
|
||||
self.pos_tags = ['UNK', 'A', 'C', 'D', 'I', 'N', 'P', 'T', 'V', 'DE', 'SHI']
|
||||
|
||||
with open(os.path.join(model_dir,'G2PWModel/bopomofo_to_pinyin_wo_tune_dict.json'), 'r',encoding='utf-8') as fr:
|
||||
self.bopomofo_convert_dict = json.load(fr)
|
||||
self.style_convert_func = {
|
||||
'bopomofo': lambda x: x,
|
||||
'pinyin': self._convert_bopomofo_to_pinyin,
|
||||
}[style]
|
||||
|
||||
with open(os.path.join(model_dir,'G2PWModel/char_bopomofo_dict.json'), 'r',encoding='utf-8') as fr:
|
||||
self.char_bopomofo_dict = json.load(fr)
|
||||
|
||||
if self.enable_opencc:
|
||||
self.cc = OpenCC('s2tw')
|
||||
|
||||
def _convert_bopomofo_to_pinyin(self, bopomofo):
|
||||
tone = bopomofo[-1]
|
||||
assert tone in '12345'
|
||||
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
|
||||
if component:
|
||||
return component + tone
|
||||
else:
|
||||
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
|
||||
return None
|
||||
|
||||
def __call__(self, sentences):
|
||||
if isinstance(sentences, str):
|
||||
sentences = [sentences]
|
||||
|
||||
if self.enable_opencc:
|
||||
translated_sentences = []
|
||||
for sent in sentences:
|
||||
translated_sent = self.cc.convert(sent)
|
||||
assert len(translated_sent) == len(sent)
|
||||
translated_sentences.append(translated_sent)
|
||||
sentences = translated_sentences
|
||||
|
||||
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences)
|
||||
if len(texts) == 0:
|
||||
# sentences no polyphonic words
|
||||
return partial_results
|
||||
|
||||
onnx_input = prepare_onnx_input(self.tokenizer, self.labels, self.char2phonemes, self.chars, texts, query_ids,
|
||||
use_mask=self.config.use_mask, use_char_phoneme=self.config.use_char_phoneme,
|
||||
window_size=None)
|
||||
|
||||
preds, confidences = predict(self.session_g2pW, onnx_input, self.labels)
|
||||
if self.config.use_char_phoneme:
|
||||
preds = [pred.split(' ')[1] for pred in preds]
|
||||
|
||||
results = partial_results
|
||||
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
|
||||
results[sent_id][query_id] = self.style_convert_func(pred)
|
||||
|
||||
return results
|
||||
|
||||
def _prepare_data(self, sentences):
|
||||
polyphonic_chars = set(self.chars)
|
||||
monophonic_chars_dict = {
|
||||
char: phoneme for char, phoneme in self.monophonic_chars
|
||||
}
|
||||
texts, query_ids, sent_ids, partial_results = [], [], [], []
|
||||
for sent_id, sent in enumerate(sentences):
|
||||
pypinyin_result = pinyin(sent,style=Style.TONE3)
|
||||
partial_result = [None] * len(sent)
|
||||
for i, char in enumerate(sent):
|
||||
if char in polyphonic_chars:
|
||||
texts.append(sent)
|
||||
query_ids.append(i)
|
||||
sent_ids.append(sent_id)
|
||||
elif char in monophonic_chars_dict:
|
||||
partial_result[i] = self.style_convert_func(monophonic_chars_dict[char])
|
||||
elif char in self.char_bopomofo_dict:
|
||||
partial_result[i] = pypinyin_result[i][0]
|
||||
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
|
||||
partial_results.append(partial_result)
|
||||
return texts, query_ids, sent_ids, partial_results
|
@ -0,0 +1,133 @@
|
||||
|
||||
"""
|
||||
Credits
|
||||
This code is modified from https://github.com/GitYCC/g2pW
|
||||
"""
|
||||
import re
|
||||
import sys
|
||||
|
||||
def wordize_and_map(text):
|
||||
words = []
|
||||
index_map_from_text_to_word = []
|
||||
index_map_from_word_to_text = []
|
||||
while len(text) > 0:
|
||||
match_space = re.match(r'^ +', text)
|
||||
if match_space:
|
||||
space_str = match_space.group(0)
|
||||
index_map_from_text_to_word += [None] * len(space_str)
|
||||
text = text[len(space_str):]
|
||||
continue
|
||||
|
||||
match_en = re.match(r'^[a-zA-Z0-9]+', text)
|
||||
if match_en:
|
||||
en_word = match_en.group(0)
|
||||
|
||||
word_start_pos = len(index_map_from_text_to_word)
|
||||
word_end_pos = word_start_pos + len(en_word)
|
||||
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
|
||||
|
||||
index_map_from_text_to_word += [len(words)] * len(en_word)
|
||||
|
||||
words.append(en_word)
|
||||
text = text[len(en_word):]
|
||||
else:
|
||||
word_start_pos = len(index_map_from_text_to_word)
|
||||
word_end_pos = word_start_pos + 1
|
||||
index_map_from_word_to_text.append((word_start_pos, word_end_pos))
|
||||
|
||||
index_map_from_text_to_word += [len(words)]
|
||||
|
||||
words.append(text[0])
|
||||
text = text[1:]
|
||||
return words, index_map_from_text_to_word, index_map_from_word_to_text
|
||||
|
||||
|
||||
def tokenize_and_map(tokenizer, text):
|
||||
words, text2word, word2text = wordize_and_map(text)
|
||||
|
||||
tokens = []
|
||||
index_map_from_token_to_text = []
|
||||
for word, (word_start, word_end) in zip(words, word2text):
|
||||
word_tokens = tokenizer.tokenize(word)
|
||||
|
||||
if len(word_tokens) == 0 or word_tokens == ['[UNK]']:
|
||||
index_map_from_token_to_text.append((word_start, word_end))
|
||||
tokens.append('[UNK]')
|
||||
else:
|
||||
current_word_start = word_start
|
||||
for word_token in word_tokens:
|
||||
word_token_len = len(re.sub(r'^##', '', word_token))
|
||||
index_map_from_token_to_text.append(
|
||||
(current_word_start, current_word_start + word_token_len))
|
||||
current_word_start = current_word_start + word_token_len
|
||||
tokens.append(word_token)
|
||||
|
||||
index_map_from_text_to_token = text2word
|
||||
for i, (token_start, token_end) in enumerate(index_map_from_token_to_text):
|
||||
for token_pos in range(token_start, token_end):
|
||||
index_map_from_text_to_token[token_pos] = i
|
||||
|
||||
return tokens, index_map_from_text_to_token, index_map_from_token_to_text
|
||||
|
||||
|
||||
def _load_config(config_path):
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location('__init__', config_path)
|
||||
config = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(config)
|
||||
return config
|
||||
|
||||
|
||||
default_config_dict = {
|
||||
'manual_seed': 1313,
|
||||
'model_source': 'bert-base-chinese',
|
||||
'window_size': 32,
|
||||
'num_workers': 2,
|
||||
'use_mask': True,
|
||||
'use_char_phoneme': False,
|
||||
'use_conditional': True,
|
||||
'param_conditional': {
|
||||
'affect_location': 'softmax',
|
||||
'bias': True,
|
||||
'char-linear': True,
|
||||
'pos-linear': False,
|
||||
'char+pos-second': True,
|
||||
|
||||
'char+pos-second_lowrank': False,
|
||||
'lowrank_size': 0,
|
||||
'char+pos-second_fm': False,
|
||||
'fm_size': 0,
|
||||
'fix_mode': None,
|
||||
'count_json': 'train.count.json'
|
||||
},
|
||||
'lr': 5e-5,
|
||||
'val_interval': 200,
|
||||
'num_iter': 10000,
|
||||
'use_focal': False,
|
||||
'param_focal': {
|
||||
'alpha': 0.0,
|
||||
'gamma': 0.7
|
||||
},
|
||||
'use_pos': True,
|
||||
'param_pos ': {
|
||||
'weight': 0.1,
|
||||
'pos_joint_training': True,
|
||||
'train_pos_path': 'train.pos',
|
||||
'valid_pos_path': 'dev.pos',
|
||||
'test_pos_path': 'test.pos'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def load_config(config_path, use_default=False):
|
||||
config = _load_config(config_path)
|
||||
if use_default:
|
||||
for attr, val in default_config_dict.items():
|
||||
if not hasattr(config, attr):
|
||||
setattr(config, attr, val)
|
||||
elif isinstance(val, dict):
|
||||
d = getattr(config, attr)
|
||||
for dict_k, dict_v in val.items():
|
||||
if dict_k not in d:
|
||||
d[dict_k] = dict_v
|
||||
return config
|
@ -0,0 +1,26 @@
|
||||
polyphonic:
|
||||
湖泊: ['hu2','po1']
|
||||
地壳: ['di4','qiao4']
|
||||
柏树: ['bai3','shu4']
|
||||
曝光: ['bao4','guang1']
|
||||
弹力: ['tan2','li4']
|
||||
字帖: ['zi4','tie4']
|
||||
口吃: ['kou3','chi1']
|
||||
包扎: ['bao1','za1']
|
||||
哪吒: ['ne2','zha1']
|
||||
说服: ['shuo1','fu2']
|
||||
识字: ['shi2','zi4']
|
||||
骨头: ['gu3','tou5']
|
||||
对称: ['dui4','chen4']
|
||||
口供: ['kou3','gong4']
|
||||
抹布: ['ma1','bu4']
|
||||
露背: ['lu4','bei4']
|
||||
圈养: ['juan4', 'yang3']
|
||||
眼眶: ['yan3', 'kuang4']
|
||||
品行: ['pin3','xing2']
|
||||
颤抖: ['chan4','dou3']
|
||||
差不多: ['cha4','bu5','duo1']
|
||||
鸭绿江: ['ya1','lu4','jiang1']
|
||||
撒切尔: ['sa4','qie4','er3']
|
||||
比比皆是: ['bi3','bi3','jie1','shi4']
|
||||
身无长物: ['shen1','wu2','chang2','wu4']
|
Loading…
Reference in new issue