You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
231 lines
7.1 KiB
231 lines
7.1 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Contains the text featurizer class."""
|
|
from pprint import pformat
|
|
|
|
import sentencepiece as spm
|
|
|
|
from ..utility import BLANK
|
|
from ..utility import EOS
|
|
from ..utility import load_dict
|
|
from ..utility import MASKCTC
|
|
from ..utility import SOS
|
|
from ..utility import SPACE
|
|
from ..utility import UNK
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ["TextFeaturizer"]

# Module logger, obtained through the project's `Log` helper.
logger = Log(__name__).getlog()
|
|
|
|
|
|
class TextFeaturizer:
    """Tokenize text and convert between text and vocabulary indices."""

    def __init__(self,
                 unit_type,
                 vocab_filepath,
                 spm_model_prefix=None,
                 maskctc=False):
        """Text featurizer, for processing or extracting features from text.

        Currently, it supports char/word/sentence-piece level tokenizing and
        conversion into a list of token indices. Note that the token indexing
        order follows the given vocabulary file.

        Args:
            unit_type (str): unit type, one of 'char', 'word', 'spm'.
            vocab_filepath (str): Filepath to load vocabulary for token
                indices conversion. When falsy, no vocabulary is loaded and
                the vocabulary attributes (``vocab_dict``, ``vocab_list``,
                ``unk_id``, ``eos_id``, ``blank_id``, ``vocab_size``) are not
                set, so ``featurize``/``defeaturize`` cannot be used.
            spm_model_prefix (str, optional): spm model prefix; the model is
                loaded from ``spm_model_prefix + '.model'``. Required when
                ``unit_type == 'spm'``. Defaults to None.
            maskctc (bool, optional): forwarded to ``load_dict`` when loading
                the vocabulary (presumably controls whether the mask-CTC
                symbol is included -- confirm in ``load_dict``).
                Defaults to False.

        Raises:
            AssertionError: if ``unit_type`` is not one of the supported types.
            ValueError: if ``unit_type == 'spm'`` but ``spm_model_prefix`` is
                None.
        """
        assert unit_type in ('char', 'spm', 'word')
        self.unit_type = unit_type
        self.unk = UNK
        self.maskctc = maskctc

        if vocab_filepath:
            (self.vocab_dict, self._id2token, self.vocab_list, self.unk_id,
             self.eos_id, self.blank_id) = self._load_vocabulary_from_file(
                 vocab_filepath, maskctc)
            self.vocab_size = len(self.vocab_list)

        if unit_type == 'spm':
            # Fail with a clear message instead of `None + '.model'` TypeError.
            if spm_model_prefix is None:
                raise ValueError(
                    "spm_model_prefix is required when unit_type is 'spm'")
            spm_model = spm_model_prefix + '.model'
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_model)

    def tokenize(self, text, replace_space=True):
        """Split ``text`` into tokens according to ``self.unit_type``.

        Args:
            text (str): text to tokenize.
            replace_space (bool): only used for 'char' units; see
                ``char_tokenize``.

        Returns:
            List[str]: tokens.
        """
        if self.unit_type == 'char':
            tokens = self.char_tokenize(text, replace_space)
        elif self.unit_type == 'word':
            tokens = self.word_tokenize(text)
        else:  # spm
            tokens = self.spm_tokenize(text)
        return tokens

    def detokenize(self, tokens):
        """Join ``tokens`` back into text according to ``self.unit_type``.

        Args:
            tokens (List[str]): tokens.

        Returns:
            str: text.
        """
        if self.unit_type == 'char':
            text = self.char_detokenize(tokens)
        elif self.unit_type == 'word':
            text = self.word_detokenize(tokens)
        else:  # spm
            text = self.spm_detokenize(tokens)
        return text

    def featurize(self, text):
        """Convert text string to a list of token indices.

        Tokens missing from the vocabulary are mapped to the id of the UNK
        symbol (a KeyError is raised if UNK itself is not in the vocabulary).

        Args:
            text (str): Text to process.

        Returns:
            List[int]: List of token indices (empty for blank text).
        """
        vocab = self.vocab_dict
        return [
            vocab[token if token in vocab else self.unk]
            for token in self.tokenize(text)
        ]

    def defeaturize(self, idxs):
        """Convert a list of token indices to text string,
        ignore index after eos_id.

        Args:
            idxs (List[int]): List of token indices.

        Returns:
            str: Text.
        """
        tokens = []
        for idx in idxs:
            # Everything from the first EOS onwards is dropped.
            if idx == self.eos_id:
                break
            tokens.append(self._id2token[idx])
        return self.detokenize(tokens)

    def char_tokenize(self, text, replace_space=True):
        """Character tokenizer.

        Args:
            text (str): text string.
            replace_space (bool): when True, each " " is replaced by the SPACE
                symbol. False only used by build_vocab.py.

        Returns:
            List[str]: tokens.
        """
        text = text.strip()
        if replace_space:
            return [SPACE if item == " " else item for item in text]
        return list(text)

    def char_detokenize(self, tokens):
        """Character detokenizer.

        Replaces the SPACE symbol with " " and concatenates the tokens.

        Args:
            tokens (List[str]): tokens.

        Returns:
            str: text string.
        """
        return "".join(t.replace(SPACE, " ") for t in tokens)

    def word_tokenize(self, text):
        """Word tokenizer, split on whitespace."""
        return text.strip().split()

    def word_detokenize(self, tokens):
        """Word detokenizer, join tokens with a single space."""
        return " ".join(tokens)

    def spm_tokenize(self, text):
        """spm tokenize.

        Args:
            text (str): text string.

        Returns:
            List[str]: sentence pieces; an empty list when ``text`` is blank
            (previously ``None``, which broke ``featurize``).
        """
        text = text.strip()
        if not text:
            return []
        return self.sp.EncodeAsPieces(text)

    def spm_detokenize(self, tokens, input_format='piece'):
        """spm detokenize.

        Args:
            tokens (List[str] | List[int]): sentence pieces when
                ``input_format == 'piece'``, piece ids when
                ``input_format == 'id'``.
            input_format (str): 'piece' or 'id'.

        Returns:
            str: text

        Raises:
            ValueError: if ``input_format`` is neither 'piece' nor 'id'.
        """
        if input_format == "piece":
            decode = self.sp.DecodePieces
        elif input_format == "id":
            decode = self.sp.DecodeIds
        else:
            raise ValueError(
                f"input_format must be 'piece' or 'id', got {input_format!r}")
        # The join is a no-op on a str result and concatenates if the decoder
        # returns a list of strings.
        return "".join(decode(tokens))

    def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool):
        """Load vocabulary from file.

        Returns:
            tuple: (token2id, id2token, vocab_list, unk_id, eos_id, blank_id).
            Special-symbol ids are -1 when the symbol is absent.
        """
        vocab_list = load_dict(vocab_filepath, maskctc)
        assert vocab_list is not None
        logger.debug(f"Vocab: {pformat(vocab_list)}")

        id2token = dict(enumerate(vocab_list))
        token2id = {token: idx for idx, token in enumerate(vocab_list)}

        def _id_of(symbol):
            # First position of `symbol` in the vocabulary, or -1 if absent.
            return vocab_list.index(symbol) if symbol in vocab_list else -1

        blank_id = _id_of(BLANK)
        maskctc_id = _id_of(MASKCTC)
        unk_id = _id_of(UNK)
        eos_id = _id_of(EOS)
        sos_id = _id_of(SOS)
        space_id = _id_of(SPACE)

        logger.info(f"BLANK id: {blank_id}")
        logger.info(f"UNK id: {unk_id}")
        logger.info(f"EOS id: {eos_id}")
        logger.info(f"SOS id: {sos_id}")
        logger.info(f"SPACE id: {space_id}")
        logger.info(f"MASKCTC id: {maskctc_id}")
        return token2id, id2token, vocab_list, unk_id, eos_id, blank_id
|