|
|
|
@ -15,6 +15,10 @@
|
|
|
|
|
Credits
|
|
|
|
|
This code is modified from https://github.com/GitYCC/g2pW
|
|
|
|
|
"""
|
|
|
|
|
from typing import Dict
|
|
|
|
|
from typing import List
|
|
|
|
|
from typing import Tuple
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from paddlespeech.t2s.frontend.g2pw.utils import tokenize_and_map
|
|
|
|
@ -23,22 +27,17 @@ ANCHOR_CHAR = '▁'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_onnx_input(tokenizer,
|
|
|
|
|
labels,
|
|
|
|
|
char2phonemes,
|
|
|
|
|
chars,
|
|
|
|
|
texts,
|
|
|
|
|
query_ids,
|
|
|
|
|
phonemes=None,
|
|
|
|
|
pos_tags=None,
|
|
|
|
|
use_mask=False,
|
|
|
|
|
use_char_phoneme=False,
|
|
|
|
|
use_pos=False,
|
|
|
|
|
window_size=None,
|
|
|
|
|
max_len=512):
|
|
|
|
|
labels: List[str],
|
|
|
|
|
char2phonemes: Dict[str, List[int]],
|
|
|
|
|
chars: List[str],
|
|
|
|
|
texts: List[str],
|
|
|
|
|
query_ids: List[int],
|
|
|
|
|
use_mask: bool=False,
|
|
|
|
|
window_size: int=None,
|
|
|
|
|
max_len: int=512) -> Dict[str, np.array]:
|
|
|
|
|
if window_size is not None:
|
|
|
|
|
truncated_texts, truncated_query_ids = _truncate_texts(window_size,
|
|
|
|
|
texts, query_ids)
|
|
|
|
|
|
|
|
|
|
truncated_texts, truncated_query_ids = _truncate_texts(
|
|
|
|
|
window_size=window_size, texts=texts, query_ids=query_ids)
|
|
|
|
|
input_ids = []
|
|
|
|
|
token_type_ids = []
|
|
|
|
|
attention_masks = []
|
|
|
|
@ -51,13 +50,19 @@ def prepare_onnx_input(tokenizer,
|
|
|
|
|
query_id = (truncated_query_ids if window_size else query_ids)[idx]
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
tokens, text2token, token2text = tokenize_and_map(tokenizer, text)
|
|
|
|
|
tokens, text2token, token2text = tokenize_and_map(
|
|
|
|
|
tokenizer=tokenizer, text=text)
|
|
|
|
|
except Exception:
|
|
|
|
|
print(f'warning: text "{text}" is invalid')
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
text, query_id, tokens, text2token, token2text = _truncate(
|
|
|
|
|
max_len, text, query_id, tokens, text2token, token2text)
|
|
|
|
|
max_len=max_len,
|
|
|
|
|
text=text,
|
|
|
|
|
query_id=query_id,
|
|
|
|
|
tokens=tokens,
|
|
|
|
|
text2token=text2token,
|
|
|
|
|
token2text=token2text)
|
|
|
|
|
|
|
|
|
|
processed_tokens = ['[CLS]'] + tokens + ['[SEP]']
|
|
|
|
|
|
|
|
|
@ -91,7 +96,8 @@ def prepare_onnx_input(tokenizer,
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _truncate_texts(window_size, texts, query_ids):
|
|
|
|
|
def _truncate_texts(window_size: int, texts: List[str],
|
|
|
|
|
query_ids: List[int]) -> Tuple[List[str], List[int]]:
|
|
|
|
|
truncated_texts = []
|
|
|
|
|
truncated_query_ids = []
|
|
|
|
|
for text, query_id in zip(texts, query_ids):
|
|
|
|
@ -105,7 +111,12 @@ def _truncate_texts(window_size, texts, query_ids):
|
|
|
|
|
return truncated_texts, truncated_query_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _truncate(max_len, text, query_id, tokens, text2token, token2text):
|
|
|
|
|
def _truncate(max_len: int,
|
|
|
|
|
text: str,
|
|
|
|
|
query_id: int,
|
|
|
|
|
tokens: List[str],
|
|
|
|
|
text2token: List[int],
|
|
|
|
|
token2text: List[Tuple[int]]):
|
|
|
|
|
truncate_len = max_len - 2
|
|
|
|
|
if len(tokens) <= truncate_len:
|
|
|
|
|
return (text, query_id, tokens, text2token, token2text)
|
|
|
|
@ -132,18 +143,8 @@ def _truncate(max_len, text, query_id, tokens, text2token, token2text):
|
|
|
|
|
], [(s - start, e - start) for s, e in token2text[token_start:token_end]])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_data(sent_path: str, lb_path: str=None):
    """Load anchored sentences (and optional phoneme labels) from disk.

    Each line of *sent_path* is one sentence containing exactly one
    ``ANCHOR_CHAR`` marking the polyphonic character to disambiguate; the
    anchor is stripped from the returned text and its character index is
    returned separately.

    Args:
        sent_path: Path to the sentence file, one anchored sentence per line.
        lb_path: Optional path to the label file, one phoneme per line,
            parallel to the sentences.

    Returns:
        ``(texts, query_ids)`` when ``lb_path`` is None, otherwise
        ``(texts, query_ids, phonemes)``.

    Raises:
        ValueError: if a sentence does not contain ``ANCHOR_CHAR``.
    """
    # Context manager + explicit encoding: the original leaked the file
    # handle and relied on the locale default encoding, which breaks on
    # the non-ASCII anchor character under non-UTF-8 locales.
    with open(sent_path, encoding='utf-8') as f:
        raw_texts = f.read().rstrip().split('\n')
    # The anchor's position identifies which character to query.
    query_ids = [raw.index(ANCHOR_CHAR) for raw in raw_texts]
    texts = [raw.replace(ANCHOR_CHAR, '') for raw in raw_texts]
    if lb_path is None:
        return texts, query_ids
    with open(lb_path, encoding='utf-8') as f:
        phonemes = f.read().rstrip().split('\n')
    return texts, query_ids, phonemes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_phoneme_labels(polyphonic_chars):
|
|
|
|
|
def get_phoneme_labels(polyphonic_chars: List[List[str]]
|
|
|
|
|
) -> Tuple[List[str], Dict[str, List[int]]]:
|
|
|
|
|
labels = sorted(list(set([phoneme for char, phoneme in polyphonic_chars])))
|
|
|
|
|
char2phonemes = {}
|
|
|
|
|
for char, phoneme in polyphonic_chars:
|
|
|
|
@ -153,7 +154,8 @@ def get_phoneme_labels(polyphonic_chars):
|
|
|
|
|
return labels, char2phonemes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_char_phoneme_labels(polyphonic_chars):
|
|
|
|
|
def get_char_phoneme_labels(polyphonic_chars: List[List[str]]
|
|
|
|
|
) -> Tuple[List[str], Dict[str, List[int]]]:
|
|
|
|
|
labels = sorted(
|
|
|
|
|
list(set([f'{char} {phoneme}' for char, phoneme in polyphonic_chars])))
|
|
|
|
|
char2phonemes = {}
|
|
|
|
|