[Config] clear the u2 decode config for asr (#1107)

* clear the u2 decode config

* rename the vocab_filepath and cmvn_path
pull/1126/head
Jackwaterveg 4 years ago committed by GitHub
parent 7c6ea14028
commit 5b446f6321
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -179,7 +179,7 @@ class ASRExecutor(BaseExecutor):
self.collate_fn_test = SpeechCollator.from_config(self.config) self.collate_fn_test = SpeechCollator.from_config(self.config)
text_feature = TextFeaturizer( text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.collate_fn_test.feature_size self.config.model.input_dim = self.collate_fn_test.feature_size
self.config.model.output_dim = text_feature.vocab_size self.config.model.output_dim = text_feature.vocab_size
@ -192,7 +192,7 @@ class ASRExecutor(BaseExecutor):
res_path, self.config.collator.spm_model_prefix) res_path, self.config.collator.spm_model_prefix)
text_feature = TextFeaturizer( text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.config.collator.feat_dim self.config.model.input_dim = self.config.collator.feat_dim
self.config.model.output_dim = text_feature.vocab_size self.config.model.output_dim = text_feature.vocab_size
@ -279,7 +279,7 @@ class ASRExecutor(BaseExecutor):
audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0) audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
text_feature = TextFeaturizer( text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
self._inputs["audio"] = audio self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len self._inputs["audio_len"] = audio_len
@ -295,7 +295,7 @@ class ASRExecutor(BaseExecutor):
""" """
text_feature = TextFeaturizer( text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
cfg = self.config.decoding cfg = self.config.decoding
audio = self._inputs["audio"] audio = self._inputs["audio"]
@ -321,13 +321,7 @@ class ASRExecutor(BaseExecutor):
audio_len, audio_len,
text_feature=text_feature, text_feature=text_feature,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size, beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight, ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size, decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks, num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -180,7 +180,7 @@ class STExecutor(BaseExecutor):
res_path, self.config.collator.spm_model_prefix) res_path, self.config.collator.spm_model_prefix)
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.config.collator.feat_dim self.config.model.input_dim = self.config.collator.feat_dim
self.config.model.output_dim = self.text_feature.vocab_size self.config.model.output_dim = self.text_feature.vocab_size
@ -292,14 +292,7 @@ class STExecutor(BaseExecutor):
audio_len, audio_len,
text_feature=self.text_feature, text_feature=self.text_feature,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=None,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size, beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
word_reward=cfg.word_reward, word_reward=cfg.word_reward,
decoding_chunk_size=cfg.decoding_chunk_size, decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks, num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -41,7 +41,7 @@ class DeepSpeech2Tester_hub():
self.audio_file = args.audio_file self.audio_file = args.audio_file
self.collate_fn_test = SpeechCollator.from_config(config) self.collate_fn_test = SpeechCollator.from_config(config)
self._text_featurizer = TextFeaturizer( self._text_featurizer = TextFeaturizer(
unit_type=config.collator.unit_type, vocab_filepath=None) unit_type=config.collator.unit_type, vocab=None)
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
result_transcripts = self.model.decode( result_transcripts = self.model.decode(

@ -286,7 +286,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def __init__(self, config, args): def __init__(self, config, args):
super().__init__(config, args) super().__init__(config, args)
self._text_featurizer = TextFeaturizer( self._text_featurizer = TextFeaturizer(
unit_type=config.collator.unit_type, vocab_filepath=None) unit_type=config.collator.unit_type, vocab=None)
def ordid2token(self, texts, texts_len): def ordid2token(self, texts, texts_len):
""" ord() id to chr() chr """ """ ord() id to chr() chr """

@ -44,7 +44,7 @@ class U2Infer():
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=config.collator.unit_type, unit_type=config.collator.unit_type,
vocab_filepath=config.collator.vocab_filepath, vocab=config.collator.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix) spm_model_prefix=config.collator.spm_model_prefix)
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu') paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
@ -91,13 +91,7 @@ class U2Infer():
ilen, ilen,
text_feature=self.text_feature, text_feature=self.text_feature,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size, beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight, ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size, decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks, num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -437,7 +437,7 @@ class U2Tester(U2Trainer):
super().__init__(config, args) super().__init__(config, args)
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list self.vocab_list = self.text_feature.vocab_list
@ -469,13 +469,7 @@ class U2Tester(U2Trainer):
audio_len, audio_len,
text_feature=self.text_feature, text_feature=self.text_feature,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size, beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight, ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size, decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks, num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -393,7 +393,7 @@ class U2Tester(U2Trainer):
super().__init__(config, args) super().__init__(config, args)
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type, unit_type=self.config.collator.unit_type,
vocab_filepath=self.config.collator.vocab_filepath, vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix) spm_model_prefix=self.config.collator.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list self.vocab_list = self.text_feature.vocab_list
@ -425,13 +425,7 @@ class U2Tester(U2Trainer):
audio_len, audio_len,
text_feature=self.text_feature, text_feature=self.text_feature,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size, beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight, ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size, decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks, num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -437,14 +437,7 @@ class U2STTester(U2STTrainer):
audio_len, audio_len,
text_feature=text_feature, text_feature=text_feature,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size, beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
word_reward=cfg.word_reward, word_reward=cfg.word_reward,
decoding_chunk_size=cfg.decoding_chunk_size, decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks, num_decoding_left_chunks=cfg.num_decoding_left_chunks,
@ -475,14 +468,7 @@ class U2STTester(U2STTrainer):
audio_len, audio_len,
text_feature=text_feature, text_feature=text_feature,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size, beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch,
ctc_weight=cfg.ctc_weight,
word_reward=cfg.word_reward, word_reward=cfg.word_reward,
decoding_chunk_size=cfg.decoding_chunk_size, decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks, num_decoding_left_chunks=cfg.num_decoding_left_chunks,

@ -55,7 +55,7 @@ class SpeechFeaturizer():
self.text_feature = TextFeaturizer( self.text_feature = TextFeaturizer(
unit_type=unit_type, unit_type=unit_type,
vocab_filepath=vocab_filepath, vocab=vocab_filepath,
spm_model_prefix=spm_model_prefix, spm_model_prefix=spm_model_prefix,
maskctc=maskctc) maskctc=maskctc)
self.vocab_size = self.text_feature.vocab_size self.vocab_size = self.text_feature.vocab_size

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains the text featurizer class.""" """Contains the text featurizer class."""
from pprint import pformat from pprint import pformat
from typing import Union
import sentencepiece as spm import sentencepiece as spm
@ -31,11 +32,7 @@ __all__ = ["TextFeaturizer"]
class TextFeaturizer(): class TextFeaturizer():
def __init__(self, def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
unit_type,
vocab_filepath,
spm_model_prefix=None,
maskctc=False):
"""Text featurizer, for processing or extracting features from text. """Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into Currently, it supports char/word/sentence-piece level tokenizing and conversion into
@ -44,7 +41,7 @@ class TextFeaturizer():
Args: Args:
unit_type (str): unit type, e.g. char, word, spm unit_type (str): unit type, e.g. char, word, spm
vocab_filepath (str): Filepath to load vocabulary for token indices conversion. vocab Option[str, list]: Filepath to load vocabulary for token indices conversion, or vocab list.
spm_model_prefix (str, optional): spm model prefix. Defaults to None. spm_model_prefix (str, optional): spm model prefix. Defaults to None.
""" """
assert unit_type in ('char', 'spm', 'word') assert unit_type in ('char', 'spm', 'word')
@ -52,12 +49,12 @@ class TextFeaturizer():
self.unk = UNK self.unk = UNK
self.maskctc = maskctc self.maskctc = maskctc
if vocab_filepath: if vocab:
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file( self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
vocab_filepath, maskctc) vocab, maskctc)
self.vocab_size = len(self.vocab_list) self.vocab_size = len(self.vocab_list)
else: else:
logger.warning("TextFeaturizer: not have vocab file.") logger.warning("TextFeaturizer: not have vocab file or vocab list.")
if unit_type == 'spm': if unit_type == 'spm':
spm_model = spm_model_prefix + '.model' spm_model = spm_model_prefix + '.model'
@ -207,9 +204,13 @@ class TextFeaturizer():
return decode(tokens) return decode(tokens)
def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool): def _load_vocabulary_from_file(self, vocab: Union[str, list],
maskctc: bool):
"""Load vocabulary from file.""" """Load vocabulary from file."""
vocab_list = load_dict(vocab_filepath, maskctc) if isinstance(vocab, list):
vocab_list = vocab
else:
vocab_list = load_dict(vocab, maskctc)
assert vocab_list is not None assert vocab_list is not None
logger.debug(f"Vocab: {pformat(vocab_list)}") logger.debug(f"Vocab: {pformat(vocab_list)}")

@ -42,7 +42,7 @@ class TextCollatorSpm():
assert (vocab_filepath is not None) assert (vocab_filepath is not None)
self.text_featurizer = TextFeaturizer( self.text_featurizer = TextFeaturizer(
unit_type=unit_type, unit_type=unit_type,
vocab_filepath=vocab_filepath, vocab=vocab_filepath,
spm_model_prefix=spm_model_prefix) spm_model_prefix=spm_model_prefix)
self.eos_id = self.text_featurizer.eos_id self.eos_id = self.text_featurizer.eos_id
self.blank_id = self.text_featurizer.blank_id self.blank_id = self.text_featurizer.blank_id

@ -717,13 +717,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
feats_lengths: paddle.Tensor, feats_lengths: paddle.Tensor,
text_feature: Dict[str, int], text_feature: Dict[str, int],
decoding_method: str, decoding_method: str,
lang_model_path: str,
beam_alpha: float,
beam_beta: float,
beam_size: int, beam_size: int,
cutoff_prob: float,
cutoff_top_n: int,
num_processes: int,
ctc_weight: float=0.0, ctc_weight: float=0.0,
decoding_chunk_size: int=-1, decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1, num_decoding_left_chunks: int=-1,
@ -737,13 +731,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
decoding_method (str): decoding mode, e.g. decoding_method (str): decoding mode, e.g.
'attention', 'ctc_greedy_search', 'attention', 'ctc_greedy_search',
'ctc_prefix_beam_search', 'attention_rescoring' 'ctc_prefix_beam_search', 'attention_rescoring'
lang_model_path (str): lm path.
beam_alpha (float): lm weight.
beam_beta (float): length penalty.
beam_size (int): beam size for search beam_size (int): beam size for search
cutoff_prob (float): for prune.
cutoff_top_n (int): for prune.
num_processes (int):
ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0. ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1. decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
<0: for decoding, use full chunk. <0: for decoding, use full chunk.
@ -839,12 +827,13 @@ class U2Model(U2DecodeModel):
def __init__(self, configs: dict): def __init__(self, configs: dict):
vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs) vocab_size, encoder, decoder, ctc = U2Model._init_from_config(configs)
model_conf = configs.get('model_conf', dict())
super().__init__( super().__init__(
vocab_size=vocab_size, vocab_size=vocab_size,
encoder=encoder, encoder=encoder,
decoder=decoder, decoder=decoder,
ctc=ctc, ctc=ctc,
**configs['model_conf']) **model_conf)
@classmethod @classmethod
def _init_from_config(cls, configs: dict): def _init_from_config(cls, configs: dict):
@ -893,7 +882,7 @@ class U2Model(U2DecodeModel):
**configs['decoder_conf']) **configs['decoder_conf'])
# ctc decoder and ctc loss # ctc decoder and ctc loss
model_conf = configs['model_conf'] model_conf = configs.get('model_conf', dict())
dropout_rate = model_conf.get('ctc_dropout_rate', 0.0) dropout_rate = model_conf.get('ctc_dropout_rate', 0.0)
grad_norm_type = model_conf.get('ctc_grad_norm_type', None) grad_norm_type = model_conf.get('ctc_grad_norm_type', None)
ctc = CTCDecoder( ctc = CTCDecoder(

@ -522,14 +522,7 @@ class U2STBaseModel(nn.Layer):
feats_lengths: paddle.Tensor, feats_lengths: paddle.Tensor,
text_feature: Dict[str, int], text_feature: Dict[str, int],
decoding_method: str, decoding_method: str,
lang_model_path: str,
beam_alpha: float,
beam_beta: float,
beam_size: int, beam_size: int,
cutoff_prob: float,
cutoff_top_n: int,
num_processes: int,
ctc_weight: float=0.0,
word_reward: float=0.0, word_reward: float=0.0,
decoding_chunk_size: int=-1, decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1, num_decoding_left_chunks: int=-1,
@ -543,14 +536,7 @@ class U2STBaseModel(nn.Layer):
decoding_method (str): decoding mode, e.g. decoding_method (str): decoding mode, e.g.
'fullsentence', 'fullsentence',
'simultaneous' 'simultaneous'
lang_model_path (str): lm path.
beam_alpha (float): lm weight.
beam_beta (float): length penalty.
beam_size (int): beam size for search beam_size (int): beam size for search
cutoff_prob (float): for prune.
cutoff_top_n (int): for prune.
num_processes (int):
ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1. decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
<0: for decoding, use full chunk. <0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set. >0: for decoding, use fixed chunk size as set.

@ -168,13 +168,17 @@ class GlobalCMVN():
norm_means=True, norm_means=True,
norm_vars=True, norm_vars=True,
std_floor=1.0e-20): std_floor=1.0e-20):
self.cmvn_path = cmvn_path # cmvn_path: Option[str, dict]
cmvn = cmvn_path
self.cmvn = cmvn
self.norm_means = norm_means self.norm_means = norm_means
self.norm_vars = norm_vars self.norm_vars = norm_vars
self.std_floor = std_floor self.std_floor = std_floor
if isinstance(cmvn, dict):
with open(cmvn_path) as f: cmvn_stats = cmvn
cmvn_stats = json.load(f) else:
with open(cmvn) as f:
cmvn_stats = json.load(f)
self.count = cmvn_stats['frame_num'] self.count = cmvn_stats['frame_num']
self.mean = np.array(cmvn_stats['mean_stat']) / self.count self.mean = np.array(cmvn_stats['mean_stat']) / self.count
self.square_sums = np.array(cmvn_stats['var_stat']) self.square_sums = np.array(cmvn_stats['var_stat'])
@ -183,7 +187,7 @@ class GlobalCMVN():
def __repr__(self): def __repr__(self):
return f"""{self.__class__.__name__}( return f"""{self.__class__.__name__}(
cmvn_path={self.cmvn_path}, cmvn_path={self.cmvn},
norm_means={self.norm_means}, norm_means={self.norm_means},
norm_vars={self.norm_vars},)""" norm_vars={self.norm_vars},)"""

Loading…
Cancel
Save