[Config]clear the u2 decode config for asr ()

* clear the u2 decode config

* rename the vocab_filepath and cmvn_path
pull/1126/head
Jackwaterveg 3 years ago committed by GitHub
parent 7c6ea14028
commit 5b446f6321
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -179,7 +179,7 @@ class ASRExecutor(BaseExecutor):
self.collate_fn_test = SpeechCollator.from_config(self.config)
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.collate_fn_test.feature_size
self.config.model.output_dim = text_feature.vocab_size
@ -192,7 +192,7 @@ class ASRExecutor(BaseExecutor):
res_path, self.config.collator.spm_model_prefix)
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.config.collator.feat_dim
self.config.model.output_dim = text_feature.vocab_size
@ -279,7 +279,7 @@ class ASRExecutor(BaseExecutor):
audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len
@ -295,7 +295,7 @@ class ASRExecutor(BaseExecutor):
"""
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
cfg = self.config.decoding
audio = self._inputs["audio"]
@ -321,13 +321,7 @@ class ASRExecutor(BaseExecutor):
audio_len,
text_feature=text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -180,7 +180,7 @@ class STExecutor(BaseExecutor):
res_path, self.config.collator.spm_model_prefix)
self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.config.collator.feat_dim
self.config.model.output_dim = self.text_feature.vocab_size
@ -292,14 +292,7 @@ class STExecutor(BaseExecutor):
audio_len,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=None,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
word_reward=cfg.word_reward,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -41,7 +41,7 @@ class DeepSpeech2Tester_hub():
self.audio_file = args.audio_file
self.collate_fn_test = SpeechCollator.from_config(config)
self._text_featurizer = TextFeaturizer(
unit_type=config.collator.unit_type, vocab_filepath=None)
unit_type=config.collator.unit_type, vocab=None)
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
result_transcripts = self.model.decode(

@ -286,7 +286,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def __init__(self, config, args):
super().__init__(config, args)
self._text_featurizer = TextFeaturizer(
unit_type=config.collator.unit_type, vocab_filepath=None)
unit_type=config.collator.unit_type, vocab=None)
def ordid2token(self, texts, texts_len):
""" ord() id to chr() chr """

@ -44,7 +44,7 @@ class U2Infer():
self.text_feature = TextFeaturizer(
unit_type=config.collator.unit_type,
vocab_filepath=config.collator.vocab_filepath,
vocab=config.collator.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix)
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
@ -91,13 +91,7 @@ class U2Infer():
ilen,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -437,7 +437,7 @@ class U2Tester(U2Trainer):
super().__init__(config, args)
self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
@ -469,13 +469,7 @@ class U2Tester(U2Trainer):
audio_len,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -393,7 +393,7 @@ class U2Tester(U2Trainer):
super().__init__(config, args)
self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list
@ -425,13 +425,7 @@ class U2Tester(U2Trainer):
audio_len,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -437,14 +437,7 @@ class U2STTester(U2STTrainer):
audio_len,
text_feature=text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
word_reward=cfg.word_reward,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
@ -475,14 +468,7 @@ class U2STTester(U2STTrainer):
audio_len,
text_feature=text_feature,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
word_reward=cfg.word_reward,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -55,7 +55,7 @@ class SpeechFeaturizer():
self.text_feature = TextFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
vocab=vocab_filepath,
spm_model_prefix=spm_model_prefix,
maskctc=maskctc)
self.vocab_size = self.text_feature.vocab_size

@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the text featurizer class."""
from pprint import pformat
from typing import Union
import sentencepiece as spm
@ -31,11 +32,7 @@ __all__ = ["TextFeaturizer"]
class TextFeaturizer():
def __init__(self,
unit_type,
vocab_filepath,
spm_model_prefix=None,
maskctc=False):
def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
@ -44,7 +41,7 @@ class TextFeaturizer():
Args:
unit_type (str): unit type, e.g. char, word, spm
vocab_filepath (str): Filepath to load vocabulary for token indices conversion.
vocab Option[str, list]: Filepath to load vocabulary for token indices conversion, or vocab list.
spm_model_prefix (str, optional): spm model prefix. Defaults to None.
"""
assert unit_type in ('char', 'spm', 'word')
@ -52,12 +49,12 @@ class TextFeaturizer():
self.unk = UNK
self.maskctc = maskctc
if vocab_filepath:
if vocab:
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
vocab_filepath, maskctc)
vocab, maskctc)
self.vocab_size = len(self.vocab_list)
else:
logger.warning("TextFeaturizer: not have vocab file.")
logger.warning("TextFeaturizer: not have vocab file or vocab list.")
if unit_type == 'spm':
spm_model = spm_model_prefix + '.model'
@ -207,9 +204,13 @@ class TextFeaturizer():
return decode(tokens)
def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool):
def _load_vocabulary_from_file(self, vocab: Union[str, list],
maskctc: bool):
"""Load vocabulary from file."""
vocab_list = load_dict(vocab_filepath, maskctc)
if isinstance(vocab, list):
vocab_list = vocab
else:
vocab_list = load_dict(vocab, maskctc)
assert vocab_list is not None
logger.debug(f"Vocab: {pformat(vocab_list)}")

@ -42,7 +42,7 @@ class TextCollatorSpm():
assert (vocab_filepath is not None)
self.text_featurizer = TextFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
vocab=vocab_filepath,
spm_model_prefix=spm_model_prefix)
self.eos_id = self.text_featurizer.eos_id
self.blank_id = self.text_featurizer.blank_id

@ -717,13 +717,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
feats_lengths: paddle.Tensor,
text_feature: Dict[str, int],
decoding_method: str,
lang_model_path: str,
beam_alpha: float,
beam_beta: float,
beam_size: int,
cutoff_prob: float,
cutoff_top_n: int,
num_processes: int,
ctc_weight: float=0.0,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
@ -737,13 +731,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
decoding_method (str): decoding mode, e.g.
'attention', 'ctc_greedy_search',
'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path (str): lm path.
beam_alpha (float): lm weight.
beam_beta (float): length penalty.
beam_size (int): beam size for search
cutoff_prob (float): for prune.
cutoff_top_n (int): for prune.
num_processes (int):
ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
<0: for decoding, use full chunk.
@ -839,12 +827,13 @@ class U2Model(U2DecodeModel):
def __init__(self, configs: dict):
vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
model_conf = configs.get('model_conf', dict())
super().__init__(
vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
**configs['model_conf'])
**model_conf)
@classmethod
def _init_from_config(cls, configs: dict):
@ -893,7 +882,7 @@ class U2Model(U2DecodeModel):
**configs['decoder_conf'])
# ctc decoder and ctc loss
model_conf = configs['model_conf']
model_conf = configs.get('model_conf', dict())
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
ctc = CTCDecoder(

@ -522,14 +522,7 @@ class U2STBaseModel(nn.Layer):
feats_lengths: paddle.Tensor,
text_feature: Dict[str, int],
decoding_method: str,
lang_model_path: str,
beam_alpha: float,
beam_beta: float,
beam_size: int,
cutoff_prob: float,
cutoff_top_n: int,
num_processes: int,
ctc_weight: float=0.0,
word_reward: float=0.0,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
@ -543,14 +536,7 @@ class U2STBaseModel(nn.Layer):
decoding_method (str): decoding mode, e.g.
'fullsentence',
'simultaneous'
lang_model_path (str): lm path.
beam_alpha (float): lm weight.
beam_beta (float): length penalty.
beam_size (int): beam size for search
cutoff_prob (float): for prune.
cutoff_top_n (int): for prune.
num_processes (int):
ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.

@ -168,13 +168,17 @@ class GlobalCMVN():
norm_means=True,
norm_vars=True,
std_floor=1.0e-20):
self.cmvn_path = cmvn_path
# cmvn_path: Option[str, dict]
cmvn = cmvn_path
self.cmvn = cmvn
self.norm_means = norm_means
self.norm_vars = norm_vars
self.std_floor = std_floor
with open(cmvn_path) as f:
cmvn_stats = json.load(f)
if isinstance(cmvn, dict):
cmvn_stats = cmvn
else:
with open(cmvn) as f:
cmvn_stats = json.load(f)
self.count = cmvn_stats['frame_num']
self.mean = np.array(cmvn_stats['mean_stat']) / self.count
self.square_sums = np.array(cmvn_stats['var_stat'])
@ -183,8 +187,8 @@ class GlobalCMVN():
def __repr__(self):
return f"""{self.__class__.__name__}(
cmvn_path={self.cmvn_path},
norm_means={self.norm_means},
cmvn_path={self.cmvn},
norm_means={self.norm_means},
norm_vars={self.norm_vars},)"""
def __call__(self, x, uttid=None):

Loading…
Cancel
Save