From 30f6d7bcbfb580392af4d3381b179929493e811c Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Sun, 9 Jan 2022 12:18:01 +0000 Subject: [PATCH 1/6] fix config, test=asr --- examples/aishell/asr0/conf/deepspeech2.yaml | 2 +- examples/aishell/asr0/conf/deepspeech2_online.yaml | 2 +- examples/tiny/asr0/conf/deepspeech2.yaml | 2 +- examples/tiny/asr0/conf/deepspeech2_online.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/aishell/asr0/conf/deepspeech2.yaml b/examples/aishell/asr0/conf/deepspeech2.yaml index 1dc8581e0..ec9e02b66 100644 --- a/examples/aishell/asr0/conf/deepspeech2.yaml +++ b/examples/aishell/asr0/conf/deepspeech2.yaml @@ -23,7 +23,7 @@ augmentation_config: conf/augmentation.json random_seed: 0 spm_model_prefix: spectrum_type: linear -feat_dim: +feat_dim: 161 delta_delta: False stride_ms: 10.0 window_ms: 20.0 diff --git a/examples/aishell/asr0/conf/deepspeech2_online.yaml b/examples/aishell/asr0/conf/deepspeech2_online.yaml index c49973a26..05594e2d7 100644 --- a/examples/aishell/asr0/conf/deepspeech2_online.yaml +++ b/examples/aishell/asr0/conf/deepspeech2_online.yaml @@ -23,7 +23,7 @@ augmentation_config: conf/augmentation.json random_seed: 0 spm_model_prefix: spectrum_type: linear #linear, mfcc, fbank -feat_dim: +feat_dim: 161 delta_delta: False stride_ms: 10.0 window_ms: 20.0 diff --git a/examples/tiny/asr0/conf/deepspeech2.yaml b/examples/tiny/asr0/conf/deepspeech2.yaml index a16a79d3a..2cc4483eb 100644 --- a/examples/tiny/asr0/conf/deepspeech2.yaml +++ b/examples/tiny/asr0/conf/deepspeech2.yaml @@ -23,7 +23,7 @@ augmentation_config: conf/augmentation.json random_seed: 0 spm_model_prefix: spectrum_type: linear -feat_dim: +feat_dim: 161 delta_delta: False stride_ms: 10.0 window_ms: 20.0 diff --git a/examples/tiny/asr0/conf/deepspeech2_online.yaml b/examples/tiny/asr0/conf/deepspeech2_online.yaml index 5458cfb30..3bd4f6350 100644 --- a/examples/tiny/asr0/conf/deepspeech2_online.yaml +++ b/examples/tiny/asr0/conf/deepspeech2_online.yaml @@ -23,7 +23,7 @@ augmentation_config: conf/augmentation.json random_seed: 0 spm_model_prefix: spectrum_type: linear -feat_dim: +feat_dim: 161 delta_delta: False stride_ms: 10.0 window_ms: 20.0 From 35ca7f6e984699d88edf4c96017294d4284f1623 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 10 Jan 2022 02:40:46 +0000 Subject: [PATCH 2/6] fix config, test=doc_fix --- examples/aishell/asr0/conf/deepspeech2.yaml | 4 ++-- examples/aishell/asr0/conf/deepspeech2_online.yaml | 4 ++-- examples/librispeech/asr0/conf/deepspeech2.yaml | 4 ++-- examples/librispeech/asr0/conf/deepspeech2_online.yaml | 4 ++-- examples/tiny/asr0/conf/deepspeech2.yaml | 4 ++-- examples/tiny/asr0/conf/deepspeech2_online.yaml | 4 ++-- examples/wenetspeech/asr1/local/data.sh | 2 +- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/aishell/asr0/conf/deepspeech2.yaml b/examples/aishell/asr0/conf/deepspeech2.yaml index ec9e02b66..fb6998647 100644 --- a/examples/aishell/asr0/conf/deepspeech2.yaml +++ b/examples/aishell/asr0/conf/deepspeech2.yaml @@ -54,9 +54,9 @@ ctc_grad_norm_type: instance ########################################### n_epoch: 80 accum_grad: 1 -lr: 2e-3 +lr: 2.0e-3 lr_decay: 0.83 -weight_decay: 1e-06 +weight_decay: 1.0e-6 global_grad_clip: 3.0 log_interval: 100 checkpoint: diff --git a/examples/aishell/asr0/conf/deepspeech2_online.yaml b/examples/aishell/asr0/conf/deepspeech2_online.yaml index 05594e2d7..ef01ac595 100644 --- a/examples/aishell/asr0/conf/deepspeech2_online.yaml +++ b/examples/aishell/asr0/conf/deepspeech2_online.yaml @@ -56,9 +56,9 @@ blank_id: 0 ########################################### n_epoch: 65 accum_grad: 1 -lr: 5e-4 +lr: 5.0e-4 lr_decay: 0.93 -weight_decay: 1e-06 +weight_decay: 1.0e-6 global_grad_clip: 3.0 log_interval: 100 checkpoint: diff --git a/examples/librispeech/asr0/conf/deepspeech2.yaml b/examples/librispeech/asr0/conf/deepspeech2.yaml index 0b0a1550d..0307b9f39 100644 --- a/examples/librispeech/asr0/conf/deepspeech2.yaml +++ b/examples/librispeech/asr0/conf/deepspeech2.yaml @@ -55,9 +55,9 @@ blank_id: 0 ########################################### n_epoch: 50 accum_grad: 1 -lr: 1e-3 +lr: 1.0e-3 lr_decay: 0.83 -weight_decay: 1e-06 +weight_decay: 1.0e-6 global_grad_clip: 5.0 log_interval: 100 checkpoint: diff --git a/examples/librispeech/asr0/conf/deepspeech2_online.yaml b/examples/librispeech/asr0/conf/deepspeech2_online.yaml index 8bd5a6727..a0d2bcfe2 100644 --- a/examples/librispeech/asr0/conf/deepspeech2_online.yaml +++ b/examples/librispeech/asr0/conf/deepspeech2_online.yaml @@ -57,9 +57,9 @@ blank_id: 0 ########################################### n_epoch: 50 accum_grad: 4 -lr: 1e-3 +lr: 1.0e-3 lr_decay: 0.83 -weight_decay: 1e-06 +weight_decay: 1.0e-6 global_grad_clip: 5.0 log_interval: 100 checkpoint: diff --git a/examples/tiny/asr0/conf/deepspeech2.yaml b/examples/tiny/asr0/conf/deepspeech2.yaml index 2cc4483eb..64d432e26 100644 --- a/examples/tiny/asr0/conf/deepspeech2.yaml +++ b/examples/tiny/asr0/conf/deepspeech2.yaml @@ -55,9 +55,9 @@ blank_id: 0 ########################################### n_epoch: 5 accum_grad: 1 -lr: 1e-5 +lr: 1.0e-5 lr_decay: 0.8 -weight_decay: 1e-06 +weight_decay: 1.0e-6 global_grad_clip: 5.0 log_interval: 1 checkpoint: diff --git a/examples/tiny/asr0/conf/deepspeech2_online.yaml b/examples/tiny/asr0/conf/deepspeech2_online.yaml index 3bd4f6350..74a4dc814 100644 --- a/examples/tiny/asr0/conf/deepspeech2_online.yaml +++ b/examples/tiny/asr0/conf/deepspeech2_online.yaml @@ -57,9 +57,9 @@ blank_id: 0 ########################################### n_epoch: 5 accum_grad: 1 -lr: 1e-5 +lr: 1.0e-5 lr_decay: 1.0 -weight_decay: 1e-06 +weight_decay: 1.0e-6 global_grad_clip: 5.0 log_interval: 1 checkpoint: diff --git a/examples/wenetspeech/asr1/local/data.sh b/examples/wenetspeech/asr1/local/data.sh index 7dd478d19..d216dd84a 100755 --- a/examples/wenetspeech/asr1/local/data.sh +++ b/examples/wenetspeech/asr1/local/data.sh @@ -96,7 +96,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then python3 ${MAIN_ROOT}/utils/build_vocab.py \ --unit_type="char" \ --count_threshold=0 \ - --vocab_path="data/vocab.txt" \ + --vocab_path="data/lang_char/vocab.txt" \ --manifest_paths "data/manifest.train.raw" if [ $? -ne 0 ]; then From 5c9e4caa7b8603f3b9ad0fe3d81cdfa99043c115 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 10 Jan 2022 09:30:10 +0000 Subject: [PATCH 3/6] add en and decode_method for cli/asr, test=asr --- paddlespeech/cli/asr/infer.py | 91 ++++++------- paddlespeech/s2t/frontend/normalizer.py | 13 +- utils/generate_infer_yaml.py | 174 ++++++++++++++++++++++++ 3 files changed, 225 insertions(+), 53 deletions(-) create mode 100644 utils/generate_infer_yaml.py diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 8de964768..53379ed71 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -46,19 +46,29 @@ pretrained_models = { # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" "conformer_wenetspeech-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', 'md5': - '54e7a558a6e020c2f5fb224874943f97', + 'b9afd8285ff5b2596bf96afab656b02f', 'cfg_path': - 'conf/conformer.yaml', + 'conf/conformer_infer.yaml', 'ckpt_path': 'exp/conformer/checkpoints/wenetspeech', }, + "transformer_librispeech-en-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz', + 'md5': + 'c95b9997f5f81478b32879a38532913d', + 'cfg_path': + 'conf/transformer_infer.yaml', + 'ckpt_path': + 'exp/transformer/checkpoints/avg_10', + }, } model_alias = { - "ds2_offline": "paddlespeech.s2t.models.ds2:DeepSpeech2Model", - "ds2_online": "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", + "deepspeech2offline": "paddlespeech.s2t.models.ds2:DeepSpeech2Model", + "deepspeech2online": "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "conformer": "paddlespeech.s2t.models.u2:U2Model", "transformer": "paddlespeech.s2t.models.u2:U2Model", "wenetspeech": "paddlespeech.s2t.models.u2:U2Model", @@ -85,7 +95,7 @@ class ASRExecutor(BaseExecutor): '--lang', type=str, default='zh', - help='Choose model language. zh or en') + help='Choose model language. zh or en, zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k]') self.parser.add_argument( "--sample_rate", type=int, @@ -97,6 +107,12 @@ class ASRExecutor(BaseExecutor): type=str, default=None, help='Config of asr task. Use deault config when it is None.') + self.parser.add_argument( + '--decode_method', + type=str, + default='attention_rescoring', + choices=['ctc_greedy_search', 'ctc_prefix_beam_search', 'attention', 'attention_rescoring'], + help='only support transformer and conformer model') self.parser.add_argument( '--ckpt_path', type=str, @@ -136,6 +152,7 @@ class ASRExecutor(BaseExecutor): lang: str='zh', sample_rate: int=16000, cfg_path: Optional[os.PathLike]=None, + decode_method: str='attention_rescoring', ckpt_path: Optional[os.PathLike]=None): """ Init model and other resources from a specific path. @@ -165,45 +182,30 @@ class ASRExecutor(BaseExecutor): #Init body. self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) - self.config.decoding.decoding_method = "attention_rescoring" with UpdateConfig(self.config): - if "ds2_online" in model_type or "ds2_offline" in model_type: + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: from paddlespeech.s2t.io.collator import SpeechCollator - self.config.collator.vocab_filepath = os.path.join( - res_path, self.config.collator.vocab_filepath) - self.config.collator.mean_std_filepath = os.path.join( - res_path, self.config.collator.cmvn_path) + self.vocab = self.config.vocab_filepath + self.config.decode.lang_model_path = os.path.join(res_path, self.config.decode.lang_model_path) self.collate_fn_test = SpeechCollator.from_config(self.config) self.text_feature = TextFeaturizer( - unit_type=self.config.collator.unit_type, - vocab=self.config.collator.vocab_filepath, - spm_model_prefix=self.config.collator.spm_model_prefix) - self.config.model.input_dim = self.collate_fn_test.feature_size - self.config.model.output_dim = self.text_feature.vocab_size + unit_type=self.config.unit_type, + vocab=self.vocab) elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - self.config.collator.vocab_filepath = os.path.join( - res_path, self.config.collator.vocab_filepath) - self.config.collator.augmentation_config = os.path.join( - res_path, self.config.collator.augmentation_config) - self.config.collator.spm_model_prefix = os.path.join( - res_path, self.config.collator.spm_model_prefix) + self.config.spm_model_prefix = os.path.join(self.res_path, self.config.spm_model_prefix) self.text_feature = TextFeaturizer( - unit_type=self.config.collator.unit_type, - vocab=self.config.collator.vocab_filepath, - spm_model_prefix=self.config.collator.spm_model_prefix) - self.config.model.input_dim = self.config.collator.feat_dim - self.config.model.output_dim = self.text_feature.vocab_size + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + self.config.decode.decoding_method = decode_method else: raise Exception("wrong type") - # Enter the path of model root - model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} model_class = dynamic_import(model_name, model_alias) - model_conf = self.config.model - logger.info(model_conf) + model_conf = self.config model = model_class.from_config(model_conf) self.model = model self.model.eval() @@ -222,7 +224,7 @@ class ASRExecutor(BaseExecutor): logger.info("Preprocess audio_file:" + audio_file) # Get the object for feature extraction - if "ds2_online" in model_type or "ds2_offline" in model_type: + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: audio, _ = self.collate_fn_test.process_utterance( audio_file=audio_file, transcript=" ") audio_len = audio.shape[0] @@ -236,18 +238,7 @@ class ASRExecutor(BaseExecutor): elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: logger.info("get the preprocess conf") - preprocess_conf_file = self.config.collator.augmentation_config - # redirect the cmvn path - with io.open(preprocess_conf_file, encoding="utf-8") as f: - preprocess_conf = yaml.safe_load(f) - for idx, process in enumerate(preprocess_conf["process"]): - if process['type'] == "cmvn_json": - preprocess_conf["process"][idx][ - "cmvn_path"] = os.path.join( - self.res_path, - preprocess_conf["process"][idx]["cmvn_path"]) - break - logger.info(preprocess_conf) + preprocess_conf = self.config.preprocess_config preprocess_args = {"train": False} preprocessing = Transformation(preprocess_conf) logger.info("read the audio file") @@ -289,10 +280,10 @@ class ASRExecutor(BaseExecutor): Model inference and result stored in self.output. """ - cfg = self.config.decoding + cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] - if "ds2_online" in model_type or "ds2_offline" in model_type: + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: result_transcripts = self.model.decode( audio, audio_len, @@ -414,12 +405,13 @@ class ASRExecutor(BaseExecutor): config = parser_args.config ckpt_path = parser_args.ckpt_path audio_file = parser_args.input + decode_method = parser_args.decode_method force_yes = parser_args.yes device = parser_args.device try: res = self(audio_file, model, lang, sample_rate, config, ckpt_path, - force_yes, device) + decode_method, force_yes, device) logger.info('ASR Result: {}'.format(res)) return True except Exception as e: @@ -434,6 +426,7 @@ class ASRExecutor(BaseExecutor): sample_rate: int=16000, config: os.PathLike=None, ckpt_path: os.PathLike=None, + decode_method: str='attention_rescoring', force_yes: bool=False, device=paddle.get_device()): """ @@ -442,7 +435,7 @@ class ASRExecutor(BaseExecutor): audio_file = os.path.abspath(audio_file) self._check(audio_file, sample_rate, force_yes) paddle.set_device(device) - self._init_from_path(model, lang, sample_rate, config, ckpt_path) + self._init_from_path(model, lang, sample_rate, config, decode_method, ckpt_path) self.preprocess(model, audio_file) self.infer(model) res = self.postprocess() # Retrieve result of asr. diff --git a/paddlespeech/s2t/frontend/normalizer.py b/paddlespeech/s2t/frontend/normalizer.py index 017851e63..b596b2ab0 100644 --- a/paddlespeech/s2t/frontend/normalizer.py +++ b/paddlespeech/s2t/frontend/normalizer.py @@ -117,7 +117,8 @@ class FeatureNormalizer(object): self._compute_mean_std(manifest_path, featurize_func, num_samples, num_workers) else: - self._read_mean_std_from_file(mean_std_filepath) + mean_std = mean_std_filepath + self._read_mean_std_from_file(mean_std) def apply(self, features): """Normalize features to be of zero mean and unit stddev. @@ -131,10 +132,14 @@ class FeatureNormalizer(object): """ return (features - self._mean) * self._istd - def _read_mean_std_from_file(self, filepath, eps=1e-20): + def _read_mean_std_from_file(self, mean_std, eps=1e-20): """Load mean and std from file.""" - filetype = filepath.split(".")[-1] - mean, istd = load_cmvn(filepath, filetype=filetype) + if isinstance(mean_std, list): + mean = mean_std[0]['cmvn_stats']['mean'] + istd = mean_std[0]['cmvn_stats']['istd'] + else: + filetype = mean_std.split(".")[-1] + mean, istd = load_cmvn(mean_std, filetype=filetype) self._mean = np.expand_dims(mean, axis=0) self._istd = np.expand_dims(istd, axis=0) diff --git a/utils/generate_infer_yaml.py b/utils/generate_infer_yaml.py new file mode 100644 index 000000000..5eed738c6 --- /dev/null +++ b/utils/generate_infer_yaml.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +''' + Merge training configs into a single inference config. +''' + +import yaml +import json +import os +import argparse +import math +from yacs.config import CfgNode + +from paddlespeech.s2t.frontend.utility import load_dict +from contextlib import redirect_stdout + + +def save(save_path, config): + with open(save_path, 'w') as fp: + with redirect_stdout(fp): + print(config.dump()) + + +def load(save_path): + config = CfgNode(new_allowed=True) + config.merge_from_file(save_path) + return config + +def load_json(json_path): + with open(json_path) as f: + json_content = json.load(f) + return json_content + +def remove_config_part(config, key_list): + if len(key_list) == 0: + return + for i in range(len(key_list) -1): + config = config[key_list[i]] + config.pop(key_list[-1]) + +def load_cmvn_from_json(cmvn_stats): + means = cmvn_stats['mean_stat'] + variance = cmvn_stats['var_stat'] + count = cmvn_stats['frame_num'] + for i in range(len(means)): + means[i] /= count + variance[i] = variance[i] / count - means[i] * means[i] + if variance[i] < 1.0e-20: + variance[i] = 1.0e-20 + variance[i] = 1.0 / math.sqrt(variance[i]) + cmvn_stats = {"mean":means, "istd":variance} + return cmvn_stats + +def merge_configs( + conf_path = "conf/conformer.yaml", + preprocess_path = "conf/preprocess.yaml", + decode_path = "conf/tuning/decode.yaml", + vocab_path = "data/vocab.txt", + cmvn_path = "data/mean_std.json", + save_path = "conf/conformer_infer.yaml", + ): + + # Load the configs + config = load(conf_path) + decode_config = load(decode_path) + vocab_list = load_dict(vocab_path) + cmvn_stats = load_json(cmvn_path) + if os.path.exists(preprocess_path): + preprocess_config = load(preprocess_path) + for idx, process in enumerate(preprocess_config["process"]): + if process['type'] == "cmvn_json": + preprocess_config["process"][idx][ + "cmvn_path"] = cmvn_stats + break + + config.preprocess_config = preprocess_config + else: + cmvn_stats = load_cmvn_from_json(cmvn_stats) + config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}] + config.augmentation_config = '' + + # Updata the config + config.vocab_filepath = vocab_list + config.input_dim = config.feat_dim + config.output_dim = len(config.vocab_filepath) + config.decode = decode_config + # Remove some parts of the config + + if os.path.exists(preprocess_path): + remove_list = ["train_manifest", + "dev_manifest", + "test_manifest", + "n_epoch", + "accum_grad", + "global_grad_clip", + "optim", + "optim_conf", + "scheduler", + "scheduler_conf", + "log_interval", + "checkpoint", + "shuffle_method", + "weight_decay", + "ctc_grad_norm_type", + "minibatches", + "batch_bins", + "batch_count", + "batch_frames_in", + "batch_frames_inout", + "batch_frames_out", + "sortagrad", + "feat_dim", + "stride_ms", + "window_ms", + "batch_size", + "maxlen_in", + "maxlen_out", + ] + else: + remove_list = ["train_manifest", + "dev_manifest", + "test_manifest", + "n_epoch", + "accum_grad", + "global_grad_clip", + "log_interval", + "checkpoint", + "lr", + "lr_decay", + "batch_size", + "shuffle_method", + "weight_decay", + "sortagrad", + "num_workers", + ] + + for item in remove_list: + try: + remove_config_part(config, [item]) + except: + print ( item + " " +"can not be removed") + + # Save the config + save(save_path, config) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog='Config merge', add_help=True) + parser.add_argument( + '--cfg_pth', type=str, default = 'conf/transformer.yaml', help='origin config file') + parser.add_argument( + '--pre_pth', type=str, default= "conf/preprocess.yaml", help='') + parser.add_argument( + '--dcd_pth', type=str, default= "conf/tuninig/decode.yaml", help='') + parser.add_argument( + '--vb_pth', type=str, default= "data/lang_char/vocab.txt", help='') + parser.add_argument( + '--cmvn_pth', type=str, default= "data/mean_std.json", help='') + parser.add_argument( + '--save_pth', type=str, default= "conf/transformer_infer.yaml", help='') + parser_args = parser.parse_args() + + merge_configs( + conf_path = parser_args.cfg_pth, + preprocess_path = parser_args.pre_pth, + vocab_path = parser_args.vb_pth, + cmvn_path = parser_args.cmvn_pth, + save_path = parser_args.save_pth, + ) + + From d902f3879119695d2d2835eccc04a1e4fb4085ee Mon Sep 17 00:00:00 2001 From: Jackwaterveg <87408988+Jackwaterveg@users.noreply.github.com> Date: Mon, 10 Jan 2022 17:56:29 +0800 Subject: [PATCH 4/6] test=asr --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4dc68c6ff..0374659b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,2 +1,11 @@ # Changelog + +Date: 2022-1-10, Author: Jackwaterveg. +Add features to: CLI: + - Support English (librispeech/asr1/transformer). + - Support choosing `decode_method` for conformer and transformer models. + - Refactor the config, using the unified config. + - Pr_link: https://github.com/PaddlePaddle/PaddleSpeech/pull/1297 + +*** From 11ba35d08be6f9826e88f3f54f403cd3e52a5fd3 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 10 Jan 2022 12:08:31 +0000 Subject: [PATCH 5/6] fix, test=doc_fix --- CHANGELOG.md | 2 +- paddlespeech/cli/asr/infer.py | 10 +++++----- utils/generate_infer_yaml.py | 10 +++++++--- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0374659b2..5ffe80984 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,6 @@ Add features to: CLI: - Support English (librispeech/asr1/transformer). - Support choosing `decode_method` for conformer and transformer models. - Refactor the config, using the unified config. - - Pr_link: https://github.com/PaddlePaddle/PaddleSpeech/pull/1297 + - PRLink: https://github.com/PaddlePaddle/PaddleSpeech/pull/1297 *** diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 53379ed71..aa4e31d9e 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -48,9 +48,9 @@ pretrained_models = { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', 'md5': - 'b9afd8285ff5b2596bf96afab656b02f', + '76cb19ed857e6623856b7cd7ebbfeda4', 'cfg_path': - 'conf/conformer_infer.yaml', + 'model.yaml', 'ckpt_path': 'exp/conformer/checkpoints/wenetspeech', }, @@ -58,9 +58,9 @@ pretrained_models = { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz', 'md5': - 'c95b9997f5f81478b32879a38532913d', + '2c667da24922aad391eacafe37bc1660', 'cfg_path': - 'conf/transformer_infer.yaml', + 'model.yaml', 'ckpt_path': 'exp/transformer/checkpoints/avg_10', }, @@ -176,7 +176,7 @@ class ASRExecutor(BaseExecutor): else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") - res_path = os.path.dirname( + self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) #Init body. diff --git a/utils/generate_infer_yaml.py b/utils/generate_infer_yaml.py index 5eed738c6..d2a6777c7 100644 --- a/utils/generate_infer_yaml.py +++ b/utils/generate_infer_yaml.py @@ -3,6 +3,8 @@ ''' Merge training configs into a single inference config. + The single inference config is for CLI, which only takes a single config to do inferencing. + The trainig configs includes: model config, preprocess config, decode config, vocab file and cmvn file. ''' import yaml @@ -88,7 +90,7 @@ def merge_configs( # Remove some parts of the config if os.path.exists(preprocess_path): - remove_list = ["train_manifest", + remove_train_list = ["train_manifest", "dev_manifest", "test_manifest", "n_epoch", @@ -104,6 +106,7 @@ def merge_configs( "weight_decay", "ctc_grad_norm_type", "minibatches", + "subsampling_factor", "batch_bins", "batch_count", "batch_frames_in", @@ -118,7 +121,7 @@ def merge_configs( "maxlen_out", ] else: - remove_list = ["train_manifest", + remove_train_list = ["train_manifest", "dev_manifest", "test_manifest", "n_epoch", @@ -135,7 +138,7 @@ def merge_configs( "num_workers", ] - for item in remove_list: + for item in remove_train_list: try: remove_config_part(config, [item]) except: @@ -165,6 +168,7 @@ if __name__ == "__main__": merge_configs( conf_path = parser_args.cfg_pth, + decode_path = parser_args.dcd_pth, preprocess_path = parser_args.pre_pth, vocab_path = parser_args.vb_pth, cmvn_path = parser_args.cmvn_pth, From fe1dc9d2111695d9159be9e89a5239e09f084aaf Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Tue, 11 Jan 2022 11:42:14 +0000 Subject: [PATCH 6/6] refactor the cli/st, test=st --- paddlespeech/cli/st/infer.py | 35 +++++++++++++++-------------------- utils/generate_infer_yaml.py | 33 +++++++++++++++++++-------------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py index d6bd6304d..1276424c5 100644 --- a/paddlespeech/cli/st/infer.py +++ b/paddlespeech/cli/st/infer.py @@ -40,11 +40,11 @@ __all__ = ["STExecutor"] pretrained_models = { "fat_st_ted-en-zh": { "url": - "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/fat_st_ted-en-zh.tar.gz", + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz", "md5": - "fa0a7425b91b4f8d259c70b2aca5ae67", + "d62063f35a16d91210a71081bd2dd557", "cfg_path": - "conf/transformer_mtl_noam.yaml", + "model.yaml", "ckpt_path": "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", } @@ -170,24 +170,19 @@ class STExecutor(BaseExecutor): #Init body. self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) - self.config.decoding.decoding_method = "fullsentence" + self.config.decode.decoding_method = "fullsentence" with UpdateConfig(self.config): - self.config.collator.vocab_filepath = os.path.join( - res_path, self.config.collator.vocab_filepath) - self.config.collator.cmvn_path = os.path.join( - res_path, self.config.collator.cmvn_path) - self.config.collator.spm_model_prefix = os.path.join( - res_path, self.config.collator.spm_model_prefix) + self.config.cmvn_path = os.path.join( + res_path, self.config.cmvn_path) + self.config.spm_model_prefix = os.path.join( + res_path, self.config.spm_model_prefix) self.text_feature = TextFeaturizer( - unit_type=self.config.collator.unit_type, - vocab=self.config.collator.vocab_filepath, - spm_model_prefix=self.config.collator.spm_model_prefix) - self.config.model.input_dim = self.config.collator.feat_dim - self.config.model.output_dim = self.text_feature.vocab_size - - model_conf = self.config.model - logger.info(model_conf) + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + + model_conf = self.config model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} model_class = dynamic_import(model_name, model_alias) @@ -218,7 +213,7 @@ class STExecutor(BaseExecutor): logger.info("Preprocess audio_file:" + audio_file) if "fat_st" in model_type: - cmvn = self.config.collator.cmvn_path + cmvn = self.config.cmvn_path utt_name = "_tmp" # Get the object for feature extraction @@ -284,7 +279,7 @@ class STExecutor(BaseExecutor): """ Model inference and result stored in self.output. """ - cfg = self.config.decoding + cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] if model_type == "fat_st_ted": diff --git a/utils/generate_infer_yaml.py b/utils/generate_infer_yaml.py index d2a6777c7..a2eb28c76 100644 --- a/utils/generate_infer_yaml.py +++ b/utils/generate_infer_yaml.py @@ -67,21 +67,26 @@ def merge_configs( config = load(conf_path) decode_config = load(decode_path) vocab_list = load_dict(vocab_path) - cmvn_stats = load_json(cmvn_path) - if os.path.exists(preprocess_path): - preprocess_config = load(preprocess_path) - for idx, process in enumerate(preprocess_config["process"]): - if process['type'] == "cmvn_json": - preprocess_config["process"][idx][ - "cmvn_path"] = cmvn_stats - break - - config.preprocess_config = preprocess_config - else: - cmvn_stats = load_cmvn_from_json(cmvn_stats) - config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}] - config.augmentation_config = '' + # If use the kaldi feature, do not load the cmvn file + if cmvn_path.split(".")[-1] == 'json': + cmvn_stats = load_json(cmvn_path) + if os.path.exists(preprocess_path): + preprocess_config = load(preprocess_path) + for idx, process in enumerate(preprocess_config["process"]): + if process['type'] == "cmvn_json": + preprocess_config["process"][idx][ + "cmvn_path"] = cmvn_stats + break + + config.preprocess_config = preprocess_config + else: + cmvn_stats = load_cmvn_from_json(cmvn_stats) + config.mean_std_filepath = [{"cmvn_stats":cmvn_stats}] + config.augmentation_config = '' + # the cmvn file is end with .ark + else: + config.cmvn_path = cmvn_path # Updata the config config.vocab_filepath = vocab_list config.input_dim = config.feat_dim