diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index b12b9f6f..49dd7b35 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -29,9 +29,10 @@ from ..download import get_path_from_url from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register -from ..utils import download_and_decompress from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import @@ -39,94 +40,13 @@ from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k". - # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "conformer_wenetspeech-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', - 'md5': - '76cb19ed857e6623856b7cd7ebbfeda4', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/conformer/checkpoints/wenetspeech', - }, - "transformer_librispeech-en-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz', - 'md5': - '2c667da24922aad391eacafe37bc1660', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/transformer/checkpoints/avg_10', - }, - "deepspeech2offline_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', - 'md5': - '932c3593d62fe5c741b59b31318aa314', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "deepspeech2online_aishell-zh-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz', - 'md5': - '23e16c69730a1cb5d735c98c83c21e16', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2_online/checkpoints/avg_1', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', - 'lm_md5': - '29e02312deb2e59b3c8686c7966d4fe3' - }, - "deepspeech2offline_librispeech-en-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz', - 'md5': - 'f5666c81ad015c8de03aac2bc92e5762', - 'cfg_path': - 'model.yaml', - 'ckpt_path': - 'exp/deepspeech2/checkpoints/avg_1', - 'lm_url': - 'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm', - 'lm_md5': - '099a601759d467cd0a8523ff939819c5' - }, -} - -model_alias = { - "deepspeech2offline": - "paddlespeech.s2t.models.ds2:DeepSpeech2Model", - "deepspeech2online": - "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", - "conformer": - "paddlespeech.s2t.models.u2:U2Model", - "transformer": - "paddlespeech.s2t.models.u2:U2Model", - "wenetspeech": - "paddlespeech.s2t.models.u2:U2Model", -} - - @cli_register( name='paddlespeech.asr', description='Speech to text infer command.') 
class ASRExecutor(BaseExecutor): def __init__(self): - super(ASRExecutor, self).__init__() + super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models self.parser = argparse.ArgumentParser( prog='paddlespeech.asr', add_help=True) @@ -136,7 +56,9 @@ class ASRExecutor(BaseExecutor): '--model', type=str, default='conformer_wenetspeech', - choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()], + choices=[ + tag[:tag.index('-')] for tag in self.pretrained_models.keys() + ], help='Choose model type of asr task.') self.parser.add_argument( '--lang', @@ -192,23 +114,6 @@ class ASRExecutor(BaseExecutor): action='store_true', help='Increase logger verbosity of current task.') - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. - """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - - return decompressed_path - def _init_from_path(self, model_type: str='wenetspeech', lang: str='zh', @@ -219,6 +124,7 @@ class ASRExecutor(BaseExecutor): """ Init model and other resources from a specific path. """ + logger.info("start to init the model") if hasattr(self, 'model'): logger.info('Model had been initialized.') return @@ -228,19 +134,21 @@ class ASRExecutor(BaseExecutor): tag = model_type + '-' + lang + '-' + sample_rate_str res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = os.path.join(res_path, - pretrained_models[tag]['cfg_path']) + self.cfg_path = os.path.join( + res_path, self.pretrained_models[tag]['cfg_path']) self.ckpt_path = os.path.join( - res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams") + res_path, + self.pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) - logger.info(self.cfg_path) - logger.info(self.ckpt_path) + else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - + logger.info(self.cfg_path) + logger.info(self.ckpt_path) + #Init body. 
self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -255,8 +163,8 @@ class ASRExecutor(BaseExecutor): self.collate_fn_test = SpeechCollator.from_config(self.config) self.text_feature = TextFeaturizer( unit_type=self.config.unit_type, vocab=self.vocab) - lm_url = pretrained_models[tag]['lm_url'] - lm_md5 = pretrained_models[tag]['lm_md5'] + lm_url = self.pretrained_models[tag]['lm_url'] + lm_md5 = self.pretrained_models[tag]['lm_md5'] self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) @@ -269,12 +177,11 @@ class ASRExecutor(BaseExecutor): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) self.config.decode.decoding_method = decode_method - else: raise Exception("wrong type") model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} - model_class = dynamic_import(model_name, model_alias) + model_class = dynamic_import(model_name, self.model_alias) model_conf = self.config model = model_class.from_config(model_conf) self.model = model @@ -347,12 +254,14 @@ class ASRExecutor(BaseExecutor): else: raise Exception("wrong type") + logger.info("audio feat process success") + @paddle.no_grad() def infer(self, model_type: str): """ Model inference and result stored in self.output. """ - + logger.info("start to infer the model to get the output") cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] @@ -369,17 +278,22 @@ class ASRExecutor(BaseExecutor): self._outputs["result"] = result_transcripts[0] elif "conformer" in model_type or "transformer" in model_type: - result_transcripts = self.model.decode( - audio, - audio_len, - text_feature=self.text_feature, - decoding_method=cfg.decoding_method, - beam_size=cfg.beam_size, - ctc_weight=cfg.ctc_weight, - decoding_chunk_size=cfg.decoding_chunk_size, - num_decoding_left_chunks=cfg.num_decoding_left_chunks, - simulate_streaming=cfg.simulate_streaming) - self._outputs["result"] = result_transcripts[0][0] + logger.info(f"we will use the transformer like model : {model_type}") + try: + result_transcripts = self.model.decode( + audio, + audio_len, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + self._outputs["result"] = result_transcripts[0][0] + except Exception as e: + logger.exception(e) + else: raise Exception("invalid model name") diff --git a/paddlespeech/cli/asr/pretrained_models.py b/paddlespeech/cli/asr/pretrained_models.py new file mode 100644 index 00000000..cc52c751 --- /dev/null +++ b/paddlespeech/cli/asr/pretrained_models.py @@ -0,0 +1,97 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
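
The tables that follow used to live inline in paddlespeech/cli/asr/infer.py and move here essentially unchanged; the one substantive addition is a new "conformer_online" entry in model_alias. As the leading comment of pretrained_models notes, tags follow "{model_name}[_{dataset}][-{lang}][-{sr}]", and every executor derives its --model choices by cutting each tag at the first '-'. A minimal sketch of that derivation, with tag values taken from this file:

    # Illustrative only: how an executor turns tags into --model choices.
    pretrained_models = {
        "conformer_wenetspeech-zh-16k": {},    # {model_name}_{dataset}-{lang}-{sr}
        "transformer_librispeech-en-16k": {},
    }
    choices = [tag[:tag.index('-')] for tag in pretrained_models]
    assert choices == ["conformer_wenetspeech", "transformer_librispeech"]

    # At load time the executor rebuilds the full tag from CLI arguments:
    model, lang, sr = "conformer_wenetspeech", "zh", "16k"
    tag = model + '-' + lang + '-' + sr   # "conformer_wenetspeech-zh-16k"
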
+ +pretrained_models = { + # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". + # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k". + # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: + # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" + "conformer_wenetspeech-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz', + 'md5': + '76cb19ed857e6623856b7cd7ebbfeda4', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/conformer/checkpoints/wenetspeech', + }, + "transformer_librispeech-en-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz', + 'md5': + '2c667da24922aad391eacafe37bc1660', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/transformer/checkpoints/avg_10', + }, + "deepspeech2offline_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz', + 'md5': + '932c3593d62fe5c741b59b31318aa314', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_1', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + "deepspeech2online_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz', + 'md5': + '23e16c69730a1cb5d735c98c83c21e16', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_1', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, + "deepspeech2offline_librispeech-en-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz', + 'md5': + 'f5666c81ad015c8de03aac2bc92e5762', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/deepspeech2/checkpoints/avg_1', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm', + 'lm_md5': + '099a601759d467cd0a8523ff939819c5' + }, +} + +model_alias = { + "deepspeech2offline": + "paddlespeech.s2t.models.ds2:DeepSpeech2Model", + "deepspeech2online": + "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", + "conformer": + "paddlespeech.s2t.models.u2:U2Model", + "conformer_online": + "paddlespeech.s2t.models.u2:U2Model", + "transformer": + "paddlespeech.s2t.models.u2:U2Model", + "wenetspeech": + "paddlespeech.s2t.models.u2:U2Model", +} diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index f56d8a57..1f637a8f 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -25,55 +25,23 @@ import yaml from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register -from ..utils import download_and_decompress -from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models from paddleaudio import load from paddleaudio.features import LogMelSpectrogram from paddlespeech.s2t.utils.dynamic_import import dynamic_import __all__ = ['CLSExecutor'] -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. 
"conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". - # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "panns_cnn6-32k": { - 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz', - 'md5': '4cf09194a95df024fd12f84712cf0f9c', - 'cfg_path': 'panns.yaml', - 'ckpt_path': 'cnn6.pdparams', - 'label_file': 'audioset_labels.txt', - }, - "panns_cnn10-32k": { - 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz', - 'md5': 'cb8427b22176cc2116367d14847f5413', - 'cfg_path': 'panns.yaml', - 'ckpt_path': 'cnn10.pdparams', - 'label_file': 'audioset_labels.txt', - }, - "panns_cnn14-32k": { - 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz', - 'md5': 'e3b9b5614a1595001161d0ab95edee97', - 'cfg_path': 'panns.yaml', - 'ckpt_path': 'cnn14.pdparams', - 'label_file': 'audioset_labels.txt', - }, -} - -model_alias = { - "panns_cnn6": "paddlespeech.cls.models.panns:CNN6", - "panns_cnn10": "paddlespeech.cls.models.panns:CNN10", - "panns_cnn14": "paddlespeech.cls.models.panns:CNN14", -} - @cli_register( name='paddlespeech.cls', description='Audio classification infer command.') class CLSExecutor(BaseExecutor): def __init__(self): - super(CLSExecutor, self).__init__() + super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models self.parser = argparse.ArgumentParser( prog='paddlespeech.cls', add_help=True) @@ -83,7 +51,9 @@ class CLSExecutor(BaseExecutor): '--model', type=str, default='panns_cnn14', - choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()], + choices=[ + tag[:tag.index('-')] for tag in self.pretrained_models.keys() + ], help='Choose model type of cls task.') self.parser.add_argument( '--config', @@ -121,23 +91,6 @@ class CLSExecutor(BaseExecutor): action='store_true', help='Increase logger verbosity of current task.') - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. 
- """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - - return decompressed_path - def _init_from_path(self, model_type: str='panns_cnn14', cfg_path: Optional[os.PathLike]=None, @@ -153,12 +106,12 @@ class CLSExecutor(BaseExecutor): if label_file is None or ckpt_path is None: tag = model_type + '-' + '32k' # panns_cnn14-32k self.res_path = self._get_pretrained_path(tag) - self.cfg_path = os.path.join(self.res_path, - pretrained_models[tag]['cfg_path']) - self.label_file = os.path.join(self.res_path, - pretrained_models[tag]['label_file']) - self.ckpt_path = os.path.join(self.res_path, - pretrained_models[tag]['ckpt_path']) + self.cfg_path = os.path.join( + self.res_path, self.pretrained_models[tag]['cfg_path']) + self.label_file = os.path.join( + self.res_path, self.pretrained_models[tag]['label_file']) + self.ckpt_path = os.path.join( + self.res_path, self.pretrained_models[tag]['ckpt_path']) else: self.cfg_path = os.path.abspath(cfg_path) self.label_file = os.path.abspath(label_file) @@ -175,7 +128,7 @@ class CLSExecutor(BaseExecutor): self._label_list.append(line.strip()) # model - model_class = dynamic_import(model_type, model_alias) + model_class = dynamic_import(model_type, self.model_alias) model_dict = paddle.load(self.ckpt_path) self.model = model_class(extract_embedding=False) self.model.set_state_dict(model_dict) diff --git a/paddlespeech/cli/cls/pretrained_models.py b/paddlespeech/cli/cls/pretrained_models.py new file mode 100644 index 00000000..1d66850a --- /dev/null +++ b/paddlespeech/cli/cls/pretrained_models.py @@ -0,0 +1,47 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". + # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". 
+ # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: + # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" + "panns_cnn6-32k": { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz', + 'md5': '4cf09194a95df024fd12f84712cf0f9c', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn6.pdparams', + 'label_file': 'audioset_labels.txt', + }, + "panns_cnn10-32k": { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz', + 'md5': 'cb8427b22176cc2116367d14847f5413', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn10.pdparams', + 'label_file': 'audioset_labels.txt', + }, + "panns_cnn14-32k": { + 'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz', + 'md5': 'e3b9b5614a1595001161d0ab95edee97', + 'cfg_path': 'panns.yaml', + 'ckpt_path': 'cnn14.pdparams', + 'label_file': 'audioset_labels.txt', + }, +} + +model_alias = { + "panns_cnn6": "paddlespeech.cls.models.panns:CNN6", + "panns_cnn10": "paddlespeech.cls.models.panns:CNN10", + "panns_cnn14": "paddlespeech.cls.models.panns:CNN14", +} diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index 064939a8..df0b6783 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -25,6 +25,8 @@ from typing import Union import paddle from .log import logger +from .utils import download_and_decompress +from .utils import MODEL_HOME class BaseExecutor(ABC): @@ -35,19 +37,8 @@ class BaseExecutor(ABC): def __init__(self): self._inputs = OrderedDict() self._outputs = OrderedDict() - - @abstractmethod - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. - - Args: - tag (str): A tag of pretrained model. - - Returns: - os.PathLike: The path on which resources of pretrained model locate. - """ - pass + self.pretrained_models = OrderedDict() + self.model_alias = OrderedDict() @abstractmethod def _init_from_path(self, *args, **kwargs): @@ -227,3 +218,20 @@ class BaseExecutor(ABC): ] for l in loggers: l.disabled = True + + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. 
+ """ + support_models = list(self.pretrained_models.keys()) + assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(self.pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path diff --git a/paddlespeech/cli/st/infer.py b/paddlespeech/cli/st/infer.py index e64fc57d..29d95f79 100644 --- a/paddlespeech/cli/st/infer.py +++ b/paddlespeech/cli/st/infer.py @@ -32,40 +32,24 @@ from ..utils import cli_register from ..utils import download_and_decompress from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import kaldi_bins +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ["STExecutor"] -pretrained_models = { - "fat_st_ted-en-zh": { - "url": - "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz", - "md5": - "d62063f35a16d91210a71081bd2dd557", - "cfg_path": - "model.yaml", - "ckpt_path": - "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", - } -} - -model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"} - -kaldi_bins = { - "url": - "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz", - "md5": - "c0682303b3f3393dbf6ed4c4e35a53eb", -} - @cli_register( name="paddlespeech.st", description="Speech translation infer command.") class STExecutor(BaseExecutor): def __init__(self): - super(STExecutor, self).__init__() + super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models + self.kaldi_bins = kaldi_bins self.parser = argparse.ArgumentParser( prog="paddlespeech.st", add_help=True) @@ -75,7 +59,9 @@ class STExecutor(BaseExecutor): "--model", type=str, default="fat_st_ted", - choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()], + choices=[ + tag[:tag.index('-')] for tag in self.pretrained_models.keys() + ], help="Choose model type of st task.") self.parser.add_argument( "--src_lang", @@ -119,28 +105,11 @@ class STExecutor(BaseExecutor): action='store_true', help='Increase logger verbosity of current task.') - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. - """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - "Use pretrained model stored in: {}".format(decompressed_path)) - - return decompressed_path - def _set_kaldi_bins(self) -> os.PathLike: """ Download and returns kaldi_bins resources path of current task. 
""" - decompressed_path = download_and_decompress(kaldi_bins, MODEL_HOME) + decompressed_path = download_and_decompress(self.kaldi_bins, MODEL_HOME) decompressed_path = os.path.abspath(decompressed_path) logger.info("Kaldi_bins stored in: {}".format(decompressed_path)) if "LD_LIBRARY_PATH" in os.environ: @@ -197,7 +166,7 @@ class STExecutor(BaseExecutor): model_conf = self.config model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} - model_class = dynamic_import(model_name, model_alias) + model_class = dynamic_import(model_name, self.model_alias) self.model = model_class.from_config(model_conf) self.model.eval() diff --git a/paddlespeech/cli/st/pretrained_models.py b/paddlespeech/cli/st/pretrained_models.py new file mode 100644 index 00000000..cc7410d2 --- /dev/null +++ b/paddlespeech/cli/st/pretrained_models.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + "fat_st_ted-en-zh": { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz", + "md5": + "d62063f35a16d91210a71081bd2dd557", + "cfg_path": + "model.yaml", + "ckpt_path": + "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams", + } +} + +model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"} + +kaldi_bins = { + "url": + "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz", + "md5": + "c0682303b3f3393dbf6ed4c4e35a53eb", +} diff --git a/paddlespeech/cli/stats/infer.py b/paddlespeech/cli/stats/infer.py index 4ef50449..7cf4f236 100644 --- a/paddlespeech/cli/stats/infer.py +++ b/paddlespeech/cli/stats/infer.py @@ -16,7 +16,6 @@ from typing import List from prettytable import PrettyTable -from ..log import logger from ..utils import cli_register from ..utils import stats_wrapper @@ -27,7 +26,8 @@ model_name_format = { 'cls': 'Model-Sample Rate', 'st': 'Model-Source language-Target language', 'text': 'Model-Task-Language', - 'tts': 'Model-Language' + 'tts': 'Model-Language', + 'vector': 'Model-Sample Rate' } @@ -36,18 +36,18 @@ model_name_format = { description='Get speech tasks support models list.') class StatsExecutor(): def __init__(self): - super(StatsExecutor, self).__init__() + super().__init__() self.parser = argparse.ArgumentParser( prog='paddlespeech.stats', add_help=True) + self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector'] self.parser.add_argument( '--task', type=str, default='asr', - choices=['asr', 'cls', 'st', 'text', 'tts'], + choices=self.task_choices, help='Choose speech task.', required=True) - self.task_choices = ['asr', 'cls', 'st', 'text', 'tts'] def show_support_models(self, pretrained_models: dict): fields = model_name_format[self.task].split("-") @@ -61,73 +61,15 @@ class StatsExecutor(): Command line entry. 
""" parser_args = self.parser.parse_args(argv) - self.task = parser_args.task - if self.task not in self.task_choices: - logger.error( - "Please input correct speech task, choices = ['asr', 'cls', 'st', 'text', 'tts']" - ) + has_exceptions = False + try: + self(parser_args.task) + except Exception as e: + has_exceptions = True + if has_exceptions: return False - - elif self.task == 'asr': - try: - from ..asr.infer import pretrained_models - logger.info( - "Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error("Failed to get the list of ASR pretrained models.") - return False - - elif self.task == 'cls': - try: - from ..cls.infer import pretrained_models - logger.info( - "Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error("Failed to get the list of CLS pretrained models.") - return False - - elif self.task == 'st': - try: - from ..st.infer import pretrained_models - logger.info( - "Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error("Failed to get the list of ST pretrained models.") - return False - - elif self.task == 'text': - try: - from ..text.infer import pretrained_models - logger.info( - "Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error( - "Failed to get the list of TEXT pretrained models.") - return False - - elif self.task == 'tts': - try: - from ..tts.infer import pretrained_models - logger.info( - "Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API" - ) - self.show_support_models(pretrained_models) - return True - except BaseException: - logger.error("Failed to get the list of TTS pretrained models.") - return False + else: + return True @stats_wrapper def __call__( @@ -138,13 +80,12 @@ class StatsExecutor(): """ self.task = task if self.task not in self.task_choices: - print( - "Please input correct speech task, choices = ['asr', 'cls', 'st', 'text', 'tts']" - ) + print("Please input correct speech task, choices = " + str( + self.task_choices)) elif self.task == 'asr': try: - from ..asr.infer import pretrained_models + from ..asr.pretrained_models import pretrained_models print( "Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API" ) @@ -154,7 +95,7 @@ class StatsExecutor(): elif self.task == 'cls': try: - from ..cls.infer import pretrained_models + from ..cls.pretrained_models import pretrained_models print( "Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API" ) @@ -164,7 +105,7 @@ class StatsExecutor(): elif self.task == 'st': try: - from ..st.infer import pretrained_models + from ..st.pretrained_models import pretrained_models print( "Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API" ) @@ -174,7 +115,7 @@ class StatsExecutor(): elif self.task == 'text': try: - from 
..text.infer import pretrained_models + from ..text.pretrained_models import pretrained_models print( "Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API" ) @@ -184,10 +125,22 @@ class StatsExecutor(): elif self.task == 'tts': try: - from ..tts.infer import pretrained_models + from ..tts.pretrained_models import pretrained_models print( "Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API" ) self.show_support_models(pretrained_models) except BaseException: print("Failed to get the list of TTS pretrained models.") + + elif self.task == 'vector': + try: + from ..vector.pretrained_models import pretrained_models + print( + "Here is the list of Speaker Recognition pretrained models released by PaddleSpeech that can be used by command line and python API" + ) + self.show_support_models(pretrained_models) + except BaseException: + print( + "Failed to get the list of Speaker Recognition pretrained models." + ) diff --git a/paddlespeech/cli/text/infer.py b/paddlespeech/cli/text/infer.py index dcf306c6..69e62e4b 100644 --- a/paddlespeech/cli/text/infer.py +++ b/paddlespeech/cli/text/infer.py @@ -25,58 +25,21 @@ from ...s2t.utils.dynamic_import import dynamic_import from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register -from ..utils import download_and_decompress -from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models +from .pretrained_models import tokenizer_alias __all__ = ['TextExecutor'] -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". - # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". 
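
The StatsExecutor rewrite above funnels both entry points through __call__ and imports each task's table from its pretrained_models module, so adding a task (here, vector) means one entry in task_choices plus one import branch. A usage sketch, assuming an installed paddlespeech and that the stats_wrapper decorator passes the call through:

    from paddlespeech.cli.stats.infer import StatsExecutor

    stats = StatsExecutor()
    stats('asr')     # prints the ASR pretrained-model table
    stats('vector')  # newly listed: speaker-recognition models
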
- # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: - # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" - "ernie_linear_p7_wudao-punc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz', - 'md5': - '12283e2ddde1797c5d1e57036b512746', - 'cfg_path': - 'ckpt/model_config.json', - 'ckpt_path': - 'ckpt/model_state.pdparams', - 'vocab_file': - 'punc_vocab.txt', - }, - "ernie_linear_p3_wudao-punc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz', - 'md5': - '448eb2fdf85b6a997e7e652e80c51dd2', - 'cfg_path': - 'ckpt/model_config.json', - 'ckpt_path': - 'ckpt/model_state.pdparams', - 'vocab_file': - 'punc_vocab.txt', - }, -} - -model_alias = { - "ernie_linear_p7": "paddlespeech.text.models:ErnieLinear", - "ernie_linear_p3": "paddlespeech.text.models:ErnieLinear", -} - -tokenizer_alias = { - "ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer", - "ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer", -} - @cli_register(name='paddlespeech.text', description='Text infer command.') class TextExecutor(BaseExecutor): def __init__(self): - super(TextExecutor, self).__init__() + super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models + self.tokenizer_alias = tokenizer_alias self.parser = argparse.ArgumentParser( prog='paddlespeech.text', add_help=True) @@ -92,7 +55,9 @@ class TextExecutor(BaseExecutor): '--model', type=str, default='ernie_linear_p7_wudao', - choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()], + choices=[ + tag[:tag.index('-')] for tag in self.pretrained_models.keys() + ], help='Choose model type of text task.') self.parser.add_argument( '--lang', @@ -131,23 +96,6 @@ class TextExecutor(BaseExecutor): action='store_true', help='Increase logger verbosity of current task.') - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. 
- """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - - return decompressed_path - def _init_from_path(self, task: str='punc', model_type: str='ernie_linear_p7_wudao', @@ -167,12 +115,12 @@ class TextExecutor(BaseExecutor): if cfg_path is None or ckpt_path is None or vocab_file is None: tag = '-'.join([model_type, task, lang]) self.res_path = self._get_pretrained_path(tag) - self.cfg_path = os.path.join(self.res_path, - pretrained_models[tag]['cfg_path']) - self.ckpt_path = os.path.join(self.res_path, - pretrained_models[tag]['ckpt_path']) - self.vocab_file = os.path.join(self.res_path, - pretrained_models[tag]['vocab_file']) + self.cfg_path = os.path.join( + self.res_path, self.pretrained_models[tag]['cfg_path']) + self.ckpt_path = os.path.join( + self.res_path, self.pretrained_models[tag]['ckpt_path']) + self.vocab_file = os.path.join( + self.res_path, self.pretrained_models[tag]['vocab_file']) else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path) @@ -187,8 +135,8 @@ class TextExecutor(BaseExecutor): self._punc_list.append(line.strip()) # model - model_class = dynamic_import(model_name, model_alias) - tokenizer_class = dynamic_import(model_name, tokenizer_alias) + model_class = dynamic_import(model_name, self.model_alias) + tokenizer_class = dynamic_import(model_name, self.tokenizer_alias) self.model = model_class( cfg_path=self.cfg_path, ckpt_path=self.ckpt_path) self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0') diff --git a/paddlespeech/cli/text/pretrained_models.py b/paddlespeech/cli/text/pretrained_models.py new file mode 100644 index 00000000..817d3caa --- /dev/null +++ b/paddlespeech/cli/text/pretrained_models.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". + # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k". 
+ # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: + # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" + "ernie_linear_p7_wudao-punc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz', + 'md5': + '12283e2ddde1797c5d1e57036b512746', + 'cfg_path': + 'ckpt/model_config.json', + 'ckpt_path': + 'ckpt/model_state.pdparams', + 'vocab_file': + 'punc_vocab.txt', + }, + "ernie_linear_p3_wudao-punc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz', + 'md5': + '448eb2fdf85b6a997e7e652e80c51dd2', + 'cfg_path': + 'ckpt/model_config.json', + 'ckpt_path': + 'ckpt/model_state.pdparams', + 'vocab_file': + 'punc_vocab.txt', + }, +} + +model_alias = { + "ernie_linear_p7": "paddlespeech.text.models:ErnieLinear", + "ernie_linear_p3": "paddlespeech.text.models:ErnieLinear", +} + +tokenizer_alias = { + "ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer", + "ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer", +} diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index 1c3fb29f..1c719930 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -29,9 +29,9 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register -from ..utils import download_and_decompress -from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend @@ -39,299 +39,14 @@ from paddlespeech.t2s.modules.normalizer import ZScore __all__ = ['TTSExecutor'] -pretrained_models = { - # speedyspeech - "speedyspeech_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip', - 'md5': - '6f6fa967b408454b6662c8c00c0027cb', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_30600.pdz', - 'speech_stats': - 'feats_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'tones_dict': - 'tone_id_map.txt', - }, - - # fastspeech2 - "fastspeech2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', - 'md5': - '637d28a5e53aa60275612ba4393d5f22', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_76000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "fastspeech2_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip', - 'md5': - 'ffed800c93deaf16ca9b3af89bfcd747', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_100000.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "fastspeech2_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip', - 'md5': - 'f4dd4a5f49a4552b77981f544ab3392e', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_96400.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'speaker_dict': - 'speaker_id_map.txt', - }, - "fastspeech2_vctk-en": { - 'url': - 
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip', - 'md5': - '743e5024ca1e17a88c5c271db9779ba4', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_66200.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - 'speaker_dict': - 'speaker_id_map.txt', - }, - # tacotron2 - "tacotron2_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip', - 'md5': - '0df4b6f0bcbe0d73c5ed6df8867ab91a', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_30600.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - "tacotron2_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip', - 'md5': - '6a5eddd81ae0e81d16959b97481135f3', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_60300.pdz', - 'speech_stats': - 'speech_stats.npy', - 'phones_dict': - 'phone_id_map.txt', - }, - - # pwgan - "pwgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip', - 'md5': - '2e481633325b5bdf0a3823c714d2c117', - 'config': - 'pwg_default.yaml', - 'ckpt': - 'pwg_snapshot_iter_400000.pdz', - 'speech_stats': - 'pwg_stats.npy', - }, - "pwgan_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip', - 'md5': - '53610ba9708fd3008ccaf8e99dacbaf0', - 'config': - 'pwg_default.yaml', - 'ckpt': - 'pwg_snapshot_iter_400000.pdz', - 'speech_stats': - 'pwg_stats.npy', - }, - "pwgan_aishell3-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip', - 'md5': - 'd7598fa41ad362d62f85ffc0f07e3d84', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1000000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "pwgan_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip', - 'md5': - 'b3da1defcde3e578be71eb284cb89f2c', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # mb_melgan - "mb_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'ee5f0604e20091f0d495b6ec4618b90d', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1000000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # style_melgan - "style_melgan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip', - 'md5': - '5de2d5348f396de0c966926b8c462755', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_1500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - # hifigan - "hifigan_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', - 'md5': - 'dd40a3d88dfcf64513fba2f0f961ada6', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_ljspeech-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip', - 'md5': - '70e9131695decbca06a65fe51ed38a72', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_aishell3-zh": { - 'url': - 
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip', - 'md5': - '3bb49bc75032ed12f79c00c8cc79a09a', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - "hifigan_vctk-en": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip', - 'md5': - '7da8f88359bca2457e705d924cf27bd4', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_2500000.pdz', - 'speech_stats': - 'feats_stats.npy', - }, - - # wavernn - "wavernn_csmsc-zh": { - 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip', - 'md5': - 'ee37b752f09bcba8f2af3b777ca38e13', - 'config': - 'default.yaml', - 'ckpt': - 'snapshot_iter_400000.pdz', - 'speech_stats': - 'feats_stats.npy', - } -} - -model_alias = { - # acoustic model - "speedyspeech": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeech", - "speedyspeech_inference": - "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference", - "fastspeech2": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2", - "fastspeech2_inference": - "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", - "tacotron2": - "paddlespeech.t2s.models.tacotron2:Tacotron2", - "tacotron2_inference": - "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", - # voc - "pwgan": - "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", - "pwgan_inference": - "paddlespeech.t2s.models.parallel_wavegan:PWGInference", - "mb_melgan": - "paddlespeech.t2s.models.melgan:MelGANGenerator", - "mb_melgan_inference": - "paddlespeech.t2s.models.melgan:MelGANInference", - "style_melgan": - "paddlespeech.t2s.models.melgan:StyleMelGANGenerator", - "style_melgan_inference": - "paddlespeech.t2s.models.melgan:StyleMelGANInference", - "hifigan": - "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", - "hifigan_inference": - "paddlespeech.t2s.models.hifigan:HiFiGANInference", - "wavernn": - "paddlespeech.t2s.models.wavernn:WaveRNN", - "wavernn_inference": - "paddlespeech.t2s.models.wavernn:WaveRNNInference", -} - @cli_register( name='paddlespeech.tts', description='Text to Speech infer command.') class TTSExecutor(BaseExecutor): def __init__(self): super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models self.parser = argparse.ArgumentParser( prog='paddlespeech.tts', add_help=True) @@ -449,22 +164,6 @@ class TTSExecutor(BaseExecutor): action='store_true', help='Increase logger verbosity of current task.') - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """ - Download and returns pretrained resources path of current task. 
- """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( - tag, '\n\t\t'.join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - return decompressed_path - def _init_from_path( self, am: str='fastspeech2_csmsc', @@ -490,16 +189,15 @@ class TTSExecutor(BaseExecutor): if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: am_res_path = self._get_pretrained_path(am_tag) self.am_res_path = am_res_path - self.am_config = os.path.join(am_res_path, - pretrained_models[am_tag]['config']) + self.am_config = os.path.join( + am_res_path, self.pretrained_models[am_tag]['config']) self.am_ckpt = os.path.join(am_res_path, - pretrained_models[am_tag]['ckpt']) + self.pretrained_models[am_tag]['ckpt']) self.am_stat = os.path.join( - am_res_path, pretrained_models[am_tag]['speech_stats']) + am_res_path, self.pretrained_models[am_tag]['speech_stats']) # must have phones_dict in acoustic self.phones_dict = os.path.join( - am_res_path, pretrained_models[am_tag]['phones_dict']) - print("self.phones_dict:", self.phones_dict) + am_res_path, self.pretrained_models[am_tag]['phones_dict']) logger.info(am_res_path) logger.info(self.am_config) logger.info(self.am_ckpt) @@ -509,21 +207,20 @@ class TTSExecutor(BaseExecutor): self.am_stat = os.path.abspath(am_stat) self.phones_dict = os.path.abspath(phones_dict) self.am_res_path = os.path.dirname(os.path.abspath(self.am_config)) - print("self.phones_dict:", self.phones_dict) # for speedyspeech self.tones_dict = None - if 'tones_dict' in pretrained_models[am_tag]: + if 'tones_dict' in self.pretrained_models[am_tag]: self.tones_dict = os.path.join( - am_res_path, pretrained_models[am_tag]['tones_dict']) + am_res_path, self.pretrained_models[am_tag]['tones_dict']) if tones_dict: self.tones_dict = tones_dict # for multi speaker fastspeech2 self.speaker_dict = None - if 'speaker_dict' in pretrained_models[am_tag]: + if 'speaker_dict' in self.pretrained_models[am_tag]: self.speaker_dict = os.path.join( - am_res_path, pretrained_models[am_tag]['speaker_dict']) + am_res_path, self.pretrained_models[am_tag]['speaker_dict']) if speaker_dict: self.speaker_dict = speaker_dict @@ -532,12 +229,12 @@ class TTSExecutor(BaseExecutor): if voc_ckpt is None or voc_config is None or voc_stat is None: voc_res_path = self._get_pretrained_path(voc_tag) self.voc_res_path = voc_res_path - self.voc_config = os.path.join(voc_res_path, - pretrained_models[voc_tag]['config']) - self.voc_ckpt = os.path.join(voc_res_path, - pretrained_models[voc_tag]['ckpt']) + self.voc_config = os.path.join( + voc_res_path, self.pretrained_models[voc_tag]['config']) + self.voc_ckpt = os.path.join( + voc_res_path, self.pretrained_models[voc_tag]['ckpt']) self.voc_stat = os.path.join( - voc_res_path, pretrained_models[voc_tag]['speech_stats']) + voc_res_path, self.pretrained_models[voc_tag]['speech_stats']) logger.info(voc_res_path) logger.info(self.voc_config) logger.info(self.voc_ckpt) @@ -588,8 +285,9 @@ class TTSExecutor(BaseExecutor): # model: {model_name}_{dataset} am_name = am[:am.rindex('_')] - am_class = dynamic_import(am_name, model_alias) - am_inference_class = dynamic_import(am_name + 
'_inference', model_alias) + am_class = dynamic_import(am_name, self.model_alias) + am_inference_class = dynamic_import(am_name + '_inference', + self.model_alias) if am_name == 'fastspeech2': am = am_class( @@ -618,9 +316,9 @@ class TTSExecutor(BaseExecutor): # vocoder # model: {model_name}_{dataset} voc_name = voc[:voc.rindex('_')] - voc_class = dynamic_import(voc_name, model_alias) + voc_class = dynamic_import(voc_name, self.model_alias) voc_inference_class = dynamic_import(voc_name + '_inference', - model_alias) + self.model_alias) if voc_name != 'wavernn': voc = voc_class(**self.voc_config["generator_params"]) voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"]) @@ -735,7 +433,6 @@ class TTSExecutor(BaseExecutor): am_ckpt = args.am_ckpt am_stat = args.am_stat phones_dict = args.phones_dict - print("phones_dict:", phones_dict) tones_dict = args.tones_dict speaker_dict = args.speaker_dict voc = args.voc diff --git a/paddlespeech/cli/tts/pretrained_models.py b/paddlespeech/cli/tts/pretrained_models.py new file mode 100644 index 00000000..65254a93 --- /dev/null +++ b/paddlespeech/cli/tts/pretrained_models.py @@ -0,0 +1,300 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + # speedyspeech + "speedyspeech_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip', + 'md5': + '6f6fa967b408454b6662c8c00c0027cb', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_30600.pdz', + 'speech_stats': + 'feats_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'tones_dict': + 'tone_id_map.txt', + }, + + # fastspeech2 + "fastspeech2_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', + 'md5': + '637d28a5e53aa60275612ba4393d5f22', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_76000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "fastspeech2_ljspeech-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip', + 'md5': + 'ffed800c93deaf16ca9b3af89bfcd747', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_100000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "fastspeech2_aishell3-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip', + 'md5': + 'f4dd4a5f49a4552b77981f544ab3392e', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_96400.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'speaker_dict': + 'speaker_id_map.txt', + }, + "fastspeech2_vctk-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip', + 'md5': + '743e5024ca1e17a88c5c271db9779ba4', + 'config': + 
'default.yaml', + 'ckpt': + 'snapshot_iter_66200.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'speaker_dict': + 'speaker_id_map.txt', + }, + # tacotron2 + "tacotron2_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip', + 'md5': + '0df4b6f0bcbe0d73c5ed6df8867ab91a', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_30600.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "tacotron2_ljspeech-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip', + 'md5': + '6a5eddd81ae0e81d16959b97481135f3', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_60300.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + + # pwgan + "pwgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip', + 'md5': + '2e481633325b5bdf0a3823c714d2c117', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_400000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + "pwgan_ljspeech-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip', + 'md5': + '53610ba9708fd3008ccaf8e99dacbaf0', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_400000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + "pwgan_aishell3-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip', + 'md5': + 'd7598fa41ad362d62f85ffc0f07e3d84', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + "pwgan_vctk-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip', + 'md5': + 'b3da1defcde3e578be71eb284cb89f2c', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + # mb_melgan + "mb_melgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'ee5f0604e20091f0d495b6ec4618b90d', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + # style_melgan + "style_melgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + '5de2d5348f396de0c966926b8c462755', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + # hifigan + "hifigan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'dd40a3d88dfcf64513fba2f0f961ada6', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + "hifigan_ljspeech-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip', + 'md5': + '70e9131695decbca06a65fe51ed38a72', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + "hifigan_aishell3-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip', + 'md5': + '3bb49bc75032ed12f79c00c8cc79a09a', + 'config': + 'default.yaml', + 'ckpt': + 
'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + "hifigan_vctk-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip', + 'md5': + '7da8f88359bca2457e705d924cf27bd4', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + + # wavernn + "wavernn_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip', + 'md5': + 'ee37b752f09bcba8f2af3b777ca38e13', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_400000.pdz', + 'speech_stats': + 'feats_stats.npy', + } +} + +model_alias = { + # acoustic model + "speedyspeech": + "paddlespeech.t2s.models.speedyspeech:SpeedySpeech", + "speedyspeech_inference": + "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference", + "fastspeech2": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2", + "fastspeech2_inference": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", + "tacotron2": + "paddlespeech.t2s.models.tacotron2:Tacotron2", + "tacotron2_inference": + "paddlespeech.t2s.models.tacotron2:Tacotron2Inference", + # voc + "pwgan": + "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", + "pwgan_inference": + "paddlespeech.t2s.models.parallel_wavegan:PWGInference", + "mb_melgan": + "paddlespeech.t2s.models.melgan:MelGANGenerator", + "mb_melgan_inference": + "paddlespeech.t2s.models.melgan:MelGANInference", + "style_melgan": + "paddlespeech.t2s.models.melgan:StyleMelGANGenerator", + "style_melgan_inference": + "paddlespeech.t2s.models.melgan:StyleMelGANInference", + "hifigan": + "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", + "hifigan_inference": + "paddlespeech.t2s.models.hifigan:HiFiGANInference", + "wavernn": + "paddlespeech.t2s.models.wavernn:WaveRNN", + "wavernn_inference": + "paddlespeech.t2s.models.wavernn:WaveRNNInference", +} diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 68e832ac..1dff6edb 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -27,45 +27,24 @@ from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger from ..utils import cli_register -from ..utils import download_and_decompress -from ..utils import MODEL_HOME from ..utils import stats_wrapper +from .pretrained_models import model_alias +from .pretrained_models import pretrained_models from paddleaudio.backends import load as load_audio from paddleaudio.compliance.librosa import melspectrogram from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.modules.sid_model import SpeakerIdetification -pretrained_models = { - # The tags for pretrained_models should be "{model_name}[-{dataset}][-{sr}][-...]". - # e.g. "ecapatdnn_voxceleb12-16k". 
- # Command line and python api use "{model_name}[-{dataset}]" as --model, usage: - # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav" - "ecapatdnn_voxceleb12-16k": { - 'url': - 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz', - 'md5': - 'cc33023c54ab346cd318408f43fcaf95', - 'cfg_path': - 'conf/model.yaml', # the yaml config path - 'ckpt_path': - 'model/model', # the format is ${dir}/{model_name}, - # so the first 'model' is dir, the second 'model' is the name - # this means we have a model stored as model/model.pdparams - }, -} - -model_alias = { - "ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn", -} - @cli_register( name="paddlespeech.vector", description="Speech to vector embedding infer command.") class VectorExecutor(BaseExecutor): def __init__(self): - super(VectorExecutor, self).__init__() + super().__init__() + self.model_alias = model_alias + self.pretrained_models = pretrained_models self.parser = argparse.ArgumentParser( prog="paddlespeech.vector", add_help=True) @@ -128,8 +107,8 @@ class VectorExecutor(BaseExecutor): Returns: bool: - False: some audio occurs error - True: all audio process success + False: some audio occurs error + True: all audio process success """ # stage 0: parse the args and get the required args parser_args = self.parser.parse_args(argv) @@ -289,32 +268,6 @@ class VectorExecutor(BaseExecutor): return res - def _get_pretrained_path(self, tag: str) -> os.PathLike: - """get the neural network path from the pretrained model list - we stored all the pretained mode in the variable `pretrained_models` - - Args: - tag (str): model tag in the pretrained model list - - Returns: - os.PathLike: the downloaded pretrained model path in the disk - """ - support_models = list(pretrained_models.keys()) - assert tag in pretrained_models, \ - 'The model "{}" you want to use has not been supported,'\ - 'please choose other models.\n' \ - 'The support models includes\n\t\t{}'.format(tag, "\n\t\t".join(support_models)) - - res_path = os.path.join(MODEL_HOME, tag) - decompressed_path = download_and_decompress(pretrained_models[tag], - res_path) - - decompressed_path = os.path.abspath(decompressed_path) - logger.info( - 'Use pretrained model stored in: {}'.format(decompressed_path)) - - return decompressed_path - def _init_from_path(self, model_type: str='ecapatdnn_voxceleb12', sample_rate: int=16000, @@ -350,10 +303,11 @@ class VectorExecutor(BaseExecutor): res_path = self._get_pretrained_path(tag) self.res_path = res_path - self.cfg_path = os.path.join(res_path, - pretrained_models[tag]['cfg_path']) + self.cfg_path = os.path.join( + res_path, self.pretrained_models[tag]['cfg_path']) self.ckpt_path = os.path.join( - res_path, pretrained_models[tag]['ckpt_path'] + '.pdparams') + res_path, + self.pretrained_models[tag]['ckpt_path'] + '.pdparams') else: # get the model from disk self.cfg_path = os.path.abspath(cfg_path) @@ -373,7 +327,7 @@ class VectorExecutor(BaseExecutor): logger.info("start to dynamic import the model class") model_name = model_type[:model_type.rindex('_')] logger.info(f"model name {model_name}") - model_class = dynamic_import(model_name, model_alias) + model_class = dynamic_import(model_name, self.model_alias) model_conf = self.config.model backbone = model_class(**model_conf) model = SpeakerIdetification( diff --git a/paddlespeech/cli/vector/pretrained_models.py b/paddlespeech/cli/vector/pretrained_models.py new file mode 100644 index 
00000000..686a22d8 --- /dev/null +++ b/paddlespeech/cli/vector/pretrained_models.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pretrained_models = { + # The tags for pretrained_models should be "{model_name}[-{dataset}][-{sr}][-...]". + # e.g. "ecapatdnn_voxceleb12-16k". + # Command line and python api use "{model_name}[-{dataset}]" as --model, usage: + # "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav" + "ecapatdnn_voxceleb12-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz', + 'md5': + 'cc33023c54ab346cd318408f43fcaf95', + 'cfg_path': + 'conf/model.yaml', # the yaml config path + 'ckpt_path': + 'model/model', # the format is ${dir}/{model_name}, + # so the first 'model' is dir, the second 'model' is the name + # this means we have a model stored as model/model.pdparams + }, +} + +model_alias = { + "ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn", +} diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 6a98607b..9b66126e 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -279,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer): # TODO(Hui Zhang): if end_flag.sum() == running_size: if end_flag.cast(paddle.int64).sum() == running_size: break - + # 2.1 Forward decoder step hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( running_size, 1, 1).to(device) # (B*N, i, i) # logp: (B*N, vocab) logp, cache = self.decoder.forward_one_step( encoder_out, encoder_mask, hyps, hyps_mask, cache) - # 2.2 First beam prune: select topk best prob at current time top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) top_k_logp = mask_finished_scores(top_k_logp, end_flag) @@ -708,11 +707,11 @@ class U2BaseModel(ASRInterface, nn.Layer): batch_size = feats.shape[0] if decoding_method in ['ctc_prefix_beam_search', 'attention_rescoring'] and batch_size > 1: - logger.fatal( + logger.error( f'decoding mode {decoding_method} must be running with batch_size == 1' ) + logger.error(f"current batch_size is {batch_size}") sys.exit(1) - if decoding_method == 'attention': hyps = self.recognize( feats, diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 33ad472d..1bb15873 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase): # init once if self._ext_scorer is not None: return - + if language_model_path != '': logger.info("begin to initialize the external scorer " "for decoding") diff --git a/paddlespeech/server/README.md b/paddlespeech/server/README.md index 819fe440..3ac68dae 100644 --- a/paddlespeech/server/README.md +++ b/paddlespeech/server/README.md @@ -35,3 +35,16 @@ ```bash paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav ``` + +## Online ASR Server + +### Launch online asr server +```bash
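+# launch the streaming ASR server (websocket protocol, conformer model, per conf/ws_conformer_application.yaml)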
+paddlespeech_server start --config_file conf/ws_conformer_application.yaml +``` + +### Access online asr server + +```bash +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav +``` \ No newline at end of file diff --git a/paddlespeech/server/README_cn.md b/paddlespeech/server/README_cn.md index c0a4a733..5f235313 100644 --- a/paddlespeech/server/README_cn.md +++ b/paddlespeech/server/README_cn.md @@ -35,3 +35,17 @@ ```bash paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav ``` + +## 流式ASR + +### 启动流式语音识别服务 + +```bash +paddlespeech_server start --config_file conf/ws_conformer_application.yaml +``` + +### 访问流式语音识别服务 + +```bash +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav +``` \ No newline at end of file diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index cb802ce5..45469178 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor): lang=lang, audio_format=audio_format) time_end = time.time() - logger.info(res.json()) + logger.info(res) logger.info("Response time %f s." % (time_end - time_start)) return True except Exception as e: logger.error("Failed to speech recognition.") + logger.error(e) return False @stats_wrapper @@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor): logging.info("asr websocket client start") handler = ASRAudioHandler(server_ip, port) loop = asyncio.get_event_loop() - loop.run_until_complete(handler.run(input)) + res = loop.run_until_complete(handler.run(input)) logging.info("asr websocket client finished") + return res['asr_results'] @cli_client_register( name='paddlespeech_client.cls', description='visit cls service') diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index b958bdf6..dee8d78b 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -41,11 +41,7 @@ asr_online: shift_ms: 40 sample_rate: 16000 sample_width: 2 - - vad_conf: - aggressiveness: 2 - sample_rate: 16000 - frame_duration_ms: 20 - sample_width: 2 - padding_ms: 200 - padding_ratio: 0.9 + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 20 # ms + shift_ms: 10 # ms diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml new file mode 100644 index 00000000..e14833de --- /dev/null +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -0,0 +1,45 @@ +# This is the parameter configuration file for PaddleSpeech Serving. + +################################################################################# +# SERVER SETTING # +################################################################################# +host: 0.0.0.0 +port: 8090 + +# The task format in the engine_list is: <speech task>_<engine type> +# task choices = ['asr_online', 'tts_online'] +# protocol = ['websocket', 'http'] (only one can be selected). +# websocket only supports the online engine type.
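+# e.g. 'asr_online' pairs the asr speech task with the online engine type.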
+protocol: 'websocket' +engine_list: ['asr_online'] + + +################################################################################# +# ENGINE CONFIG # +################################################################################# + +################################### ASR ######################################### +################### speech task: asr; engine_type: online ####################### +asr_online: + model_type: 'conformer_online_multicn' + am_model: # the pdmodel file of am static model [optional] + am_params: # the pdiparams file of am static model [optional] + lang: 'zh' + sample_rate: 16000 + cfg_path: + decode_method: + force_yes: True + + am_predictor_conf: + device: # set 'gpu:id' or 'cpu' + switch_ir_optim: True + glog_info: False # True -> print glog + summary: True # False -> do not show predictor config + + chunk_buffer_conf: + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 25 # ms + shift_ms: 10 # ms + sample_rate: 16000 + sample_width: 2 \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index ca82b615..758cbaab 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import os from typing import Optional @@ -20,12 +21,19 @@ from numpy import float32 from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.asr.infer import model_alias from paddlespeech.cli.log import logger +from paddlespeech.cli.utils import download_and_decompress from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.s2t.utils.tensor_utils import add_sos_eos +from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor @@ -35,9 +43,9 @@ __all__ = ['ASREngine'] pretrained_models = { "deepspeech2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz', 'md5': - 'd5e076217cf60486519f72c217d21b9b', + '23e16c69730a1cb5d735c98c83c21e16', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -51,16 +59,543 @@ pretrained_models = { 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' }, + "conformer_online_multicn-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', + 'md5': + '0ac93d390552336f2a906aec9e33c5fa', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/multi_cn', + 'model': + 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', 
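+        # note: 'model' and 'params' below both point to the same dygraph checkpoint (.pdparams)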
+ 'params': + 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, } + +# ASR server connection process class +class PaddleASRConnectionHanddler: + def __init__(self, asr_engine): + """Init a Paddle ASR Connection Handler instance + + Args: + asr_engine (ASREngine): the global asr engine + """ + super().__init__() + logger.info( + "create a paddle asr connection handler to process the websocket connection" + ) + self.config = asr_engine.config + self.model_config = asr_engine.executor.config + self.asr_engine = asr_engine + + self.init() + self.reset() + + def init(self): + # model_type, sample_rate and text_feature are shared for deepspeech2 and conformer + self.model_type = self.asr_engine.executor.model_type + self.sample_rate = self.asr_engine.executor.sample_rate + # tokens to text + self.text_feature = self.asr_engine.executor.text_feature + + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + from paddlespeech.s2t.io.collator import SpeechCollator + self.am_predictor = self.asr_engine.executor.am_predictor + + self.collate_fn_test = SpeechCollator.from_config(self.model_config) + self.decoder = CTCDecoder( + odim=self.model_config.output_dim, # is in vocab + enc_n_units=self.model_config.rnn_layer_size * 2, + blank_id=self.model_config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=self.model_config.get('ctc_grad_norm_type', + None)) + + cfg = self.model_config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + # frame window samples length and frame shift samples length + + self.win_length = int(self.model_config.window_ms / 1000 * + self.sample_rate) + self.n_shift = int(self.model_config.stride_ms / 1000 * + self.sample_rate) + + elif "conformer" in self.model_type or "transformer" in self.model_type: + # acoustic model + self.model = self.asr_engine.executor.model + + # ctc decoding config + self.ctc_decode_config = self.asr_engine.executor.config.decode + self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) + + # extract feat; for now only fbank is used in the conformer model + self.preprocess_conf = self.model_config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + + # frame window samples length and frame shift samples length + self.win_length = self.preprocess_conf.process[0]['win_length'] + self.n_shift = self.preprocess_conf.process[0]['n_shift'] + + def extract_feat(self, samples): + if "deepspeech2online" in self.model_type: + # self.remained_wav stores all the samples, + # including the original remained_wav and this package's samples + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + # pcm16 -> pcm32 + # pcm2float will change the original samples, + # so we should do pcm2float before concatenating + samples = pcm2float(samples) + + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The connection retains the audio samples: {self.remained_wav.shape}" + ) + + # read audio + speech_segment =
SpeechSegment.from_pcm( + self.remained_wav, self.sample_rate, transcript=" ") + # audio augment + self.collate_fn_test.augmentation.transform_audio(speech_segment) + + # extract speech feature + spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( + speech_segment, self.collate_fn_test.keep_transcription_text) + # CMVN spectrum + if self.collate_fn_test._normalizer: + spectrum = self.collate_fn_test._normalizer.apply(spectrum) + + # spectrum augment + audio = self.collate_fn_test.augmentation.transform_feature( + spectrum) + + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + # audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + + if self.cached_feat is None: + self.cached_feat = audio + else: + assert (len(audio.shape) == 3) + assert (len(self.cached_feat.shape) == 3) + self.cached_feat = paddle.concat( + [self.cached_feat, audio], axis=1) + + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + + self.num_frames += audio_len + self.remained_wav = self.remained_wav[self.n_shift * audio_len:] + + logger.info( + f"processed the audio feature successfully, the connection feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After feature extraction, the connection retains the audio samples: {self.remained_wav.shape}" + ) + elif "conformer_online" in self.model_type: + logger.info("Online ASR extracts the feat") + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + logger.info(f"This package received {samples.shape[0]} pcm samples") + self.num_samples += samples.shape[0] + + # self.remained_wav stores all the samples, + # including the original remained_wav and this package's samples + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The connection retains the audio samples: {self.remained_wav.shape}" + ) + if len(self.remained_wav) < self.win_length: + return 0 + + # fbank + x_chunk = self.preprocessing(self.remained_wav, + **self.preprocess_args) + x_chunk = paddle.to_tensor( + x_chunk, dtype="float32").unsqueeze(axis=0) + if self.cached_feat is None: + self.cached_feat = x_chunk + else: + assert (len(x_chunk.shape) == 3) + assert (len(self.cached_feat.shape) == 3) + self.cached_feat = paddle.concat( + [self.cached_feat, x_chunk], axis=1) + + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + + num_frames = x_chunk.shape[1] + self.num_frames += num_frames + self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + + logger.info( + f"processed the audio feature successfully, the connection feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After feature extraction, the connection retains the audio samples: {self.remained_wav.shape}" + ) + # logger.info(f"accumulate samples: {self.num_samples}") + + def reset(self): + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + # for deepspeech2 + self.chunk_state_h_box = copy.deepcopy( + self.asr_engine.executor.chunk_state_h_box) + self.chunk_state_c_box = copy.deepcopy( + self.asr_engine.executor.chunk_state_c_box) + self.decoder.reset_decoder(batch_size=1) + + # for conformer online + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_out = None + self.cached_feat = None + self.remained_wav = None + self.offset = 0 + self.num_samples
= 0 + self.device = None + self.hyps = [] + self.num_frames = 0 + self.chunk_num = 0 + self.global_frame_offset = 0 + self.result_transcripts = [''] + + def decode(self, is_finished=False): + if "deepspeech2online" in self.model_type: + # x_chunk is the feature data + decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model + context = 7 # context=7 in deepspeech2 model + subsampling = 4 # subsampling=4 in deepspeech2 model + stride = subsampling * decoding_chunk_size + cached_feature_num = context - subsampling + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + + if self.cached_feat is None: + logger.info("no audio feat, please input more pcm data") + return + + num_frames = self.cached_feat.shape[1] + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + # the cached feat must be larger than the decoding window + if num_frames < decoding_window and not is_finished: + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) + return None, None + + # if is_finished=True, we need at least context frames + if num_frames < context: + logger.info( + f"last {num_frames} is less than context {context} frames, and we cannot do model forward" + ) + return None, None + logger.info("start to do model forward") + # num_frames - context + 1 ensures that the current frame can get a context window + if is_finished: + # if we get the finished chunk, we need to process the last context + left_frames = context + else: + # we only process decoding_window frames for one chunk + left_frames = decoding_window + + for cur in range(0, num_frames - left_frames + 1, stride): + end = min(cur + decoding_window, num_frames) + # extract the audio + x_chunk = self.cached_feat[:, cur:end, :].numpy() + x_chunk_lens = np.array([x_chunk.shape[1]]) + trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens) + + self.result_transcripts = [trans_best] + + self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] + # return trans_best[0] + elif "conformer" in self.model_type or "transformer" in self.model_type: + try: + logger.info( + f"we will use the transformer like model : {self.model_type}" + ) + self.advance_decoding(is_finished) + self.update_result() + + except Exception as e: + logger.exception(e) + else: + raise Exception("invalid model name") + + @paddle.no_grad() + def decode_one_chunk(self, x_chunk, x_chunk_lens): + logger.info("start to decode one chunk with the deepspeech2 model") + input_names = self.am_predictor.get_input_names() + audio_handle = self.am_predictor.get_input_handle(input_names[0]) + audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) + h_box_handle = self.am_predictor.get_input_handle(input_names[2]) + c_box_handle = self.am_predictor.get_input_handle(input_names[3]) + + audio_handle.reshape(x_chunk.shape) + audio_handle.copy_from_cpu(x_chunk) + + audio_len_handle.reshape(x_chunk_lens.shape) + audio_len_handle.copy_from_cpu(x_chunk_lens) + + h_box_handle.reshape(self.chunk_state_h_box.shape) + h_box_handle.copy_from_cpu(self.chunk_state_h_box) + + c_box_handle.reshape(self.chunk_state_c_box.shape) + c_box_handle.copy_from_cpu(self.chunk_state_c_box) + + output_names = self.am_predictor.get_output_names() + output_handle = self.am_predictor.get_output_handle(output_names[0]) + output_lens_handle = self.am_predictor.get_output_handle( + output_names[1]) + output_state_h_handle = self.am_predictor.get_output_handle( + output_names[2]) +
output_state_c_handle = self.am_predictor.get_output_handle( + output_names[3]) + + self.am_predictor.run() + + output_chunk_probs = output_handle.copy_to_cpu() + output_chunk_lens = output_lens_handle.copy_to_cpu() + self.chunk_state_h_box = output_state_h_handle.copy_to_cpu() + self.chunk_state_c_box = output_state_c_handle.copy_to_cpu() + + self.decoder.next(output_chunk_probs, output_chunk_lens) + trans_best, trans_beam = self.decoder.decode() + logger.info(f"decode one best result: {trans_best[0]}") + return trans_best[0] + + @paddle.no_grad() + def advance_decoding(self, is_finished=False): + logger.info("start to decode with advanced_decoding method") + cfg = self.ctc_decode_config + decoding_chunk_size = cfg.decoding_chunk_size + num_decoding_left_chunks = cfg.num_decoding_left_chunks + + assert decoding_chunk_size > 0 + subsampling = self.model.encoder.embed.subsampling_rate + context = self.model.encoder.embed.right_context + 1 + stride = subsampling * decoding_chunk_size + cached_feature_num = context - subsampling # processed chunk feature cached for next chunk + + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + if self.cached_feat is None: + logger.info("no audio feat, please input more pcm data") + return + + num_frames = self.cached_feat.shape[1] + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + + # the cached feat must be larger than the decoding window + if num_frames < decoding_window and not is_finished: + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) + return None, None + + # if is_finished=True, we need at least context frames + if num_frames < context: + logger.info( + f"last {num_frames} is less than context {context} frames, and we cannot do model forward" + ) + return None, None + + logger.info("start to do model forward") + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + outputs = [] + + # num_frames - context + 1 ensures that the current frame can get a context window + if is_finished: + # if we get the finished chunk, we need to process the last context + left_frames = context + else: + # we only process decoding_window frames for one chunk + left_frames = decoding_window + + # record the end for removing the processed feat + end = None + for cur in range(0, num_frames - left_frames + 1, stride): + end = min(cur + decoding_window, num_frames) + + self.chunk_num += 1 + chunk_xs = self.cached_feat[:, cur:end, :] + (y, self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + chunk_xs, self.offset, required_cache_size, + self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) + outputs.append(y) + + # update the offset + self.offset += y.shape[1] + + ys = paddle.cat(outputs, 1) + if self.encoder_out is None: + self.encoder_out = ys + else: + self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) + + # get the ctc probs + ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + + self.searcher.search(ctc_probs, self.cached_feat.place) + + self.hyps = self.searcher.get_one_best_hyps() + assert self.cached_feat.shape[0] == 1 + assert end >= cached_feature_num + + self.cached_feat = self.cached_feat[0, end - + cached_feature_num:, :].unsqueeze(0) + assert len( + self.cached_feat.shape + ) == 3, f"current cache feat shape is: {self.cached_feat.shape}" +
logger.info( + f"This connection handler encoder out shape: {self.encoder_out.shape}" + ) + + def update_result(self): + logger.info("update the final result") + hyps = self.hyps + self.result_transcripts = [ + self.text_feature.defeaturize(hyp) for hyp in hyps + ] + self.result_tokenids = [hyp for hyp in hyps] + + def get_result(self): + if len(self.result_transcripts) > 0: + return self.result_transcripts[0] + else: + return '' + + @paddle.no_grad() + def rescoring(self): + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + return + + logger.info("rescoring the final result") + if "attention_rescoring" != self.ctc_decode_config.decoding_method: + return + + self.searcher.finalize_search() + self.update_result() + + beam_size = self.ctc_decode_config.beam_size + hyps = self.searcher.get_hyps() + if hyps is None or len(hyps) == 0: + return + + # assert len(hyps) == beam_size + hyp_list = [] + for hyp in hyps: + hyp_content = hyp[0] + # Prevent an empty hyp + if len(hyp_content) == 0: + hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( + hyp_content, place=self.device, dtype=paddle.long) + hyp_list.append(hyp_content) + hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + hyps_lens = paddle.to_tensor( + [len(hyp[0]) for hyp in hyps], place=self.device, + dtype=paddle.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, + self.model.ignore_id) + hyps_lens = hyps_lens + 1 # Add <sos> at the beginning + + encoder_out = self.encoder_out.repeat(beam_size, 1, 1) + encoder_mask = paddle.ones( + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + decoder_out, _ = self.model.decoder( + encoder_out, encoder_mask, hyps_pad, + hyps_lens) # (beam_size, max_hyps_len, vocab_size) + # decoder score in ln domain + decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) + decoder_out = decoder_out.numpy() + + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + # last decoder output token is `eos`, for the last decoder input token. + score += decoder_out[i][len(hyp[0])][self.model.eos] + # add the ctc score (which is in ln domain) + score += hyp[1] * self.ctc_decode_config.ctc_weight + if score > best_score: + best_score = score + best_index = i + + # update the one best result + logger.info(f"best index: {best_index}") + self.hyps = [hyps[best_index][0]] + self.update_result() + + class ASRServerExecutor(ASRExecutor): def __init__(self): super().__init__() pass + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and return the pretrained resources path of the current task.
+ """ + support_models = list(pretrained_models.keys()) + assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path + def _init_from_path(self, - model_type: str='wenetspeech', + model_type: str='deepspeech2online_aishell', am_model: Optional[os.PathLike]=None, am_params: Optional[os.PathLike]=None, lang: str='zh', @@ -71,12 +606,15 @@ class ASRServerExecutor(ASRExecutor): """ Init model and other resources from a specific path. """ - + self.model_type = model_type + self.sample_rate = sample_rate if cfg_path is None or am_model is None or am_params is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str + logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path + self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) @@ -85,9 +623,6 @@ class ASRServerExecutor(ASRExecutor): self.am_params = os.path.join(res_path, pretrained_models[tag]['params']) logger.info(res_path) - logger.info(self.cfg_path) - logger.info(self.am_model) - logger.info(self.am_params) else: self.cfg_path = os.path.abspath(cfg_path) self.am_model = os.path.abspath(am_model) @@ -95,6 +630,10 @@ class ASRServerExecutor(ASRExecutor): self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) + logger.info(self.cfg_path) + logger.info(self.am_model) + logger.info(self.am_params) + #Init body. 
self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -112,59 +651,107 @@ class ASRServerExecutor(ASRExecutor): lm_url = pretrained_models[tag]['lm_url'] lm_md5 = pretrained_models[tag]['lm_md5'] + logger.info(f"Start to load language model {lm_url}") self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - raise Exception("wrong type") + elif "conformer" in model_type or "transformer" in model_type: + logger.info("start to create the stream conformer asr engine") + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + self.vocab = self.config.vocab_filepath + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + # update the decoding method + if decode_method: + self.config.decode.decoding_method = decode_method + + # we only support the ctc_prefix_beam_search and attention_rescoring decoding methods + # Generally we set the decoding_method to attention_rescoring + if self.config.decode.decoding_method not in [ + "ctc_prefix_beam_search", "attention_rescoring" + ]: + logger.info( + "we set the decoding_method to attention_rescoring") + self.config.decode.decoding_method = "attention_rescoring" + assert self.config.decode.decoding_method in [ + "ctc_prefix_beam_search", "attention_rescoring" + ], f"we only support the ctc_prefix_beam_search and attention_rescoring decoding methods, current decoding method is {self.config.decode.decoding_method}" else: raise Exception("wrong type") + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) - # AM predictor - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - - # decoder - self.decoder = CTCDecoder( - odim=self.config.output_dim, # is in vocab - enc_n_units=self.config.rnn_layer_size * 2, - blank_id=self.config.blank_id, - dropout_rate=0.0, - reduction=True, # sum - batch_average=True, # sum / batch_size - grad_norm_type=self.config.get('ctc_grad_norm_type', None)) - - # init decoder - cfg = self.config.decode - decode_batch_size = 1 # for online - self.decoder.init_decoder( - decode_batch_size, self.text_feature.vocab_list, - cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, - cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, - cfg.num_proc_bsearch) - - # init state box - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) + # decoder + logger.info("ASR engine start to create the ctc decoder instance") + self.decoder = CTCDecoder( + odim=self.config.output_dim, # is in vocab + enc_n_units=self.config.rnn_layer_size * 2, + blank_id=self.config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=self.config.get('ctc_grad_norm_type', None)) + + # init decoder + logger.info("ASR
engine start to init the ctc decoder") + cfg = self.config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + + # init state box + self.chunk_state_h_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + elif "conformer" in model_type or "transformer" in model_type: + model_name = model_type[:model_type.rindex( + '_')] # model_type: {model_name}_{dataset} + logger.info(f"model name: {model_name}") + model_class = dynamic_import(model_name, model_alias) + model_conf = self.config + model = model_class.from_config(model_conf) + self.model = model + self.model.eval() + + # load model + model_dict = paddle.load(self.am_model) + self.model.set_state_dict(model_dict) + logger.info("create the transformer like model success") + + # update the ctc decoding + self.searcher = CTCPrefixBeamSearch(self.config.decode) + self.transformer_decode_reset() def reset_decoder_and_chunk(self): """reset decoder and chunk state for an new audio """ - self.decoder.reset_decoder(batch_size=1) - # init state box, for new audio request - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + self.decoder.reset_decoder(batch_size=1) + # init state box, for new audio request + self.chunk_state_h_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + elif "conformer" in self.model_type or "transformer" in self.model_type: + self.transformer_decode_reset() def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str): """decode one chunk @@ -175,8 +762,9 @@ class ASRServerExecutor(ASRExecutor): model_type (str): online model type Returns: - [type]: [description] + str: one best result """ + logger.info("start to decode chunk by chunk") if "deepspeech2online" in model_type: input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) h_box_handle = self.am_predictor.get_input_handle(input_names[2]) c_box_handle = self.am_predictor.get_input_handle(input_names[3]) @@ -215,14 +803,142 @@ class ASRServerExecutor(ASRExecutor): self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() - + logger.info(f"decode one best result: {trans_best[0]}") return trans_best[0] elif "conformer" in model_type or "transformer" in model_type: - raise Exception("invalid model name") + try: + logger.info( + f"we will use the transformer like model : {self.model_type}" + ) + self.advanced_decoding(x_chunk, x_chunk_lens) + self.update_result() + + return self.result_transcripts[0] + except Exception as e: + logger.exception(e) else: raise Exception("invalid model name") + def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): + logger.info("start to decode with advanced_decoding method") + encoder_out, encoder_mask = self.encoder_forward(xs) + ctc_probs = self.model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) +
self.searcher.search(ctc_probs, xs.place) + # update the one best result + self.hyps = self.searcher.get_one_best_hyps() + + # now we support ctc_prefix_beam_search and attention_rescoring + if "attention_rescoring" in self.config.decode.decoding_method: + self.rescoring(encoder_out, xs.place) + + def encoder_forward(self, xs): + logger.info("get the model out from the feat") + cfg = self.config.decode + decoding_chunk_size = cfg.decoding_chunk_size + num_decoding_left_chunks = cfg.num_decoding_left_chunks + + assert decoding_chunk_size > 0 + subsampling = self.model.encoder.embed.subsampling_rate + context = self.model.encoder.embed.right_context + 1 + stride = subsampling * decoding_chunk_size + + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.shape[1] + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + logger.info("start to do model forward") + outputs = [] + + # num_frames - context + 1 ensures that the current frame can get a context window + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + chunk_xs, self.offset, required_cache_size, + self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) + outputs.append(y) + self.offset += y.shape[1] + + ys = paddle.cat(outputs, 1) + masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) + masks = masks.unsqueeze(1) + return ys, masks + + def rescoring(self, encoder_out, device): + logger.info("start to rescore the hyps") + beam_size = self.config.decode.beam_size + hyps = self.searcher.get_hyps() + assert len(hyps) == beam_size + + hyp_list = [] + for hyp in hyps: + hyp_content = hyp[0] + # Prevent an empty hyp + if len(hyp_content) == 0: + hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( + hyp_content, place=device, dtype=paddle.long) + hyp_list.append(hyp_content) + hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + hyps_lens = paddle.to_tensor( + [len(hyp[0]) for hyp in hyps], place=device, + dtype=paddle.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, + self.model.ignore_id) + hyps_lens = hyps_lens + 1 # Add <sos> at the beginning + + encoder_out = encoder_out.repeat(beam_size, 1, 1) + encoder_mask = paddle.ones( + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + decoder_out, _ = self.model.decoder( + encoder_out, encoder_mask, hyps_pad, + hyps_lens) # (beam_size, max_hyps_len, vocab_size) + # decoder score in ln domain + decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) + decoder_out = decoder_out.numpy() + + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + # last decoder output token is `eos`, for the last decoder input token.
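+ # i.e. score(hyp) = sum_j log P_dec(w_j) + log P_dec(eos) + ctc_weight * ctc_score, all in ln domain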
+ score += decoder_out[i][len(hyp[0])][self.model.eos] + # add the ctc score (which is in ln domain) + score += hyp[1] * self.config.decode.ctc_weight + if score > best_score: + best_score = score + best_index = i + + # update the one best result + self.hyps = [hyps[best_index][0]] + return hyps[best_index][0] + + def transformer_decode_reset(self): + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.offset = 0 + # decoding reset + self.searcher.reset() + + def update_result(self): + logger.info("update the final result") + hyps = self.hyps + self.result_transcripts = [ + self.text_feature.defeaturize(hyp) for hyp in hyps + ] + self.result_tokenids = [hyp for hyp in hyps] + def extract_feat(self, samples, sample_rate): """extract feat @@ -234,34 +950,58 @@ class ASRServerExecutor(ASRExecutor): x_chunk (numpy.array): shape[B, T, D] x_chunk_lens (numpy.array): shape[B] """ - # pcm16 -> pcm 32 - samples = pcm2float(samples) - # read audio - speech_segment = SpeechSegment.from_pcm( - samples, sample_rate, transcript=" ") - # audio augment - self.collate_fn_test.augmentation.transform_audio(speech_segment) + if "deepspeech2online" in self.model_type: + # pcm16 -> pcm 32 + samples = pcm2float(samples) + # read audio + speech_segment = SpeechSegment.from_pcm( + samples, sample_rate, transcript=" ") + # audio augment + self.collate_fn_test.augmentation.transform_audio(speech_segment) - # extract speech feature - spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( - speech_segment, self.collate_fn_test.keep_transcription_text) - # CMVN spectrum - if self.collate_fn_test._normalizer: - spectrum = self.collate_fn_test._normalizer.apply(spectrum) + # extract speech feature + spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( + speech_segment, self.collate_fn_test.keep_transcription_text) + # CMVN spectrum + if self.collate_fn_test._normalizer: + spectrum = self.collate_fn_test._normalizer.apply(spectrum) - # spectrum augment - audio = self.collate_fn_test.augmentation.transform_feature(spectrum) + # spectrum augment + audio = self.collate_fn_test.augmentation.transform_feature( + spectrum) - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - # audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + # audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) - x_chunk = audio.numpy() - x_chunk_lens = np.array([audio_len]) + x_chunk = audio.numpy() + x_chunk_lens = np.array([audio_len]) - return x_chunk, x_chunk_lens + return x_chunk, x_chunk_lens + elif "conformer_online" in self.model_type: + + if sample_rate != self.sample_rate: + logger.info(f"audio sample rate {sample_rate} does not match " + f"the model sample_rate {self.sample_rate}") + logger.info(f"ASR Engine uses the {self.model_type} to process") + logger.info("Create the preprocess instance") + preprocess_conf = self.config.preprocess_config + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + + logger.info("Read the audio file") + logger.info(f"audio shape: {samples.shape}") + # fbank + x_chunk = preprocessing(samples, **preprocess_args) + x_chunk_lens = paddle.to_tensor(x_chunk.shape[0]) + x_chunk = paddle.to_tensor( + x_chunk, dtype="float32").unsqueeze(axis=0) + logger.info( + f"processed the audio feature successfully, feat shape: {x_chunk.shape}" +
) + return x_chunk, x_chunk_lens class ASREngine(BaseEngine): @@ -273,6 +1013,7 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() + logger.info("create the online asr engine instance") def init(self, config: dict) -> bool: """init engine resource @@ -301,7 +1042,10 @@ class ASREngine(BaseEngine): logger.info("Initialize ASR server engine successfully.") return True - def preprocess(self, samples, sample_rate): + def preprocess(self, + samples, + sample_rate, + model_type="deepspeech2online_aishell-zh-16k"): """preprocess Args: @@ -312,6 +1056,7 @@ class ASREngine(BaseEngine): x_chunk (numpy.array): shape[B, T, D] x_chunk_lens (numpy.array): shape[B] """ + # if "deepspeech" in model_type: x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate) return x_chunk, x_chunk_lens diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py new file mode 100644 index 00000000..8aee0a50 --- /dev/null +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict +import paddle +from paddlespeech.cli.log import logger +from paddlespeech.s2t.utils.utility import log_add + +__all__ = ['CTCPrefixBeamSearch'] + + +class CTCPrefixBeamSearch: + def __init__(self, config): + """Implement the ctc prefix beam search + + Args: + config (yacs.config.CfgNode): the ctc decoding config (provides beam_size) + """ + self.config = config + self.reset() + + @paddle.no_grad() + def search(self, ctc_probs, device, blank_id=0): + """ctc prefix beam search method decode a chunk feature + + Args: + ctc_probs (paddle.Tensor): the ctc probability of all the tokens + device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0). + blank_id (int, optional): the blank id in the vocab. Defaults to 0. + + Returns: + list: the search result + """ + # decode + logger.info("start the ctc prefix beam search") + + batch_size = 1 + beam_size = self.config.beam_size + maxlen = ctc_probs.shape[0] + + assert len(ctc_probs.shape) == 2 + + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + # blank_ending_score and none_blank_ending_score in ln domain + if self.cur_hyps is None: + self.cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2.
CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + + # 2.1 First beam prune: select topk best + # do token passing process + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in self.cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == blank_id: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + self.cur_hyps = next_hyps[:beam_size] + + self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] + logger.info("ctc prefix search succeeded") + return self.hyps + + def get_one_best_hyps(self): + """Return the one best result + + Returns: + list: the one best result + """ + return [self.hyps[0][0]] + + def get_hyps(self): + """Return the search hyps + + Returns: + list: return the search hyps + """ + return self.hyps + + def reset(self): + """Reset the search cache values + """ + self.cur_hyps = None + self.hyps = None + + def finalize_search(self): + """do nothing in ctc_prefix_beam_search + """ + pass diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/tts_engine.py index 25a8bc76..c9135b88 100644 --- a/paddlespeech/server/engine/tts/online/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/tts_engine.py @@ -12,24 +12,329 @@ # See the License for the specific language governing permissions and # limitations under the License.
import base64 +import math +import os import time +from typing import Optional import numpy as np import paddle +import yaml +from yacs.config import CfgNode from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.cli.utils import download_and_decompress +from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import float2pcm +from paddlespeech.server.utils.util import denorm from paddlespeech.server.utils.util import get_chunks +from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.zh_frontend import Frontend +from paddlespeech.t2s.modules.normalizer import ZScore + +__all__ = ['TTSEngine'] + +# support online model +pretrained_models = { + # fastspeech2 + "fastspeech2_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', + 'md5': + '637d28a5e53aa60275612ba4393d5f22', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_76000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "fastspeech2_cnndecoder_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip', + 'md5': + '6eb28e22ace73e0ebe7845f86478f89f', + 'config': + 'cnndecoder.yaml', + 'ckpt': + 'snapshot_iter_153000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + + # mb_melgan + "mb_melgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'ee5f0604e20091f0d495b6ec4618b90d', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + + # hifigan + "hifigan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'dd40a3d88dfcf64513fba2f0f961ada6', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, +} + +model_alias = { + # acoustic model + "fastspeech2": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2", + "fastspeech2_inference": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", + + # voc + "mb_melgan": + "paddlespeech.t2s.models.melgan:MelGANGenerator", + "mb_melgan_inference": + "paddlespeech.t2s.models.melgan:MelGANInference", + "hifigan": + "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", + "hifigan_inference": + "paddlespeech.t2s.models.hifigan:HiFiGANInference", +} __all__ = ['TTSEngine'] class TTSServerExecutor(TTSExecutor): - def __init__(self): + def __init__(self, am_block, am_pad, voc_block, voc_pad): super().__init__() - pass + self.am_block = am_block + self.am_pad = am_pad + self.voc_block = voc_block + self.voc_pad = voc_pad + + def get_model_info(self, + field: str, + model_name: str, + ckpt: Optional[os.PathLike], + stat: Optional[os.PathLike]): + """get model information + + Args: + field (str): am or voc + model_name (str): model type, support fastspeech2, hifigan, mb_melgan + ckpt (Optional[os.PathLike]): ckpt file + stat (Optional[os.PathLike]): stat file, including mean and standard deviation + + Returns: + [module]: model module + [Tensor]: mean + [Tensor]: standard deviation + """ + + model_class = dynamic_import(model_name,
+        model_class = dynamic_import(model_name, model_alias)
+
+        if field == "am":
+            odim = self.am_config.n_mels
+            model = model_class(
+                idim=self.vocab_size, odim=odim, **self.am_config["model"])
+            model.set_state_dict(paddle.load(ckpt)["main_params"])
+
+        elif field == "voc":
+            model = model_class(**self.voc_config["generator_params"])
+            model.set_state_dict(paddle.load(ckpt)["generator_params"])
+            model.remove_weight_norm()
+
+        else:
+            logger.error("Please set correct field, am or voc")
+            raise ValueError("field must be 'am' or 'voc'")
+
+        model.eval()
+        model_mu, model_std = np.load(stat)
+        model_mu = paddle.to_tensor(model_mu)
+        model_std = paddle.to_tensor(model_std)
+
+        return model, model_mu, model_std
+
+    def _get_pretrained_path(self, tag: str) -> os.PathLike:
+        """
+        Download and return the pretrained resource path for the current task.
+        """
+        support_models = list(pretrained_models.keys())
+        assert tag in pretrained_models, 'The model "{}" you want to use is not supported, please choose another model.\nThe supported models are:\n\t\t{}\n'.format(
+            tag, '\n\t\t'.join(support_models))
+
+        res_path = os.path.join(MODEL_HOME, tag)
+        decompressed_path = download_and_decompress(pretrained_models[tag],
+                                                    res_path)
+        decompressed_path = os.path.abspath(decompressed_path)
+        logger.info(
+            'Use pretrained model stored in: {}'.format(decompressed_path))
+        return decompressed_path
+
+    def _init_from_path(
+            self,
+            am: str='fastspeech2_csmsc',
+            am_config: Optional[os.PathLike]=None,
+            am_ckpt: Optional[os.PathLike]=None,
+            am_stat: Optional[os.PathLike]=None,
+            phones_dict: Optional[os.PathLike]=None,
+            tones_dict: Optional[os.PathLike]=None,
+            speaker_dict: Optional[os.PathLike]=None,
+            voc: str='mb_melgan_csmsc',
+            voc_config: Optional[os.PathLike]=None,
+            voc_ckpt: Optional[os.PathLike]=None,
+            voc_stat: Optional[os.PathLike]=None,
+            lang: str='zh', ):
+        """
+        Init model and other resources from a specific path.
+ """ + if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'): + logger.info('Models had been initialized.') + return + # am model info + am_tag = am + '-' + lang + if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: + am_res_path = self._get_pretrained_path(am_tag) + self.am_res_path = am_res_path + self.am_config = os.path.join(am_res_path, + pretrained_models[am_tag]['config']) + self.am_ckpt = os.path.join(am_res_path, + pretrained_models[am_tag]['ckpt']) + self.am_stat = os.path.join( + am_res_path, pretrained_models[am_tag]['speech_stats']) + # must have phones_dict in acoustic + self.phones_dict = os.path.join( + am_res_path, pretrained_models[am_tag]['phones_dict']) + print("self.phones_dict:", self.phones_dict) + logger.info(am_res_path) + logger.info(self.am_config) + logger.info(self.am_ckpt) + else: + self.am_config = os.path.abspath(am_config) + self.am_ckpt = os.path.abspath(am_ckpt) + self.am_stat = os.path.abspath(am_stat) + self.phones_dict = os.path.abspath(phones_dict) + self.am_res_path = os.path.dirname(os.path.abspath(self.am_config)) + print("self.phones_dict:", self.phones_dict) + + self.tones_dict = None + self.speaker_dict = None + + # voc model info + voc_tag = voc + '-' + lang + if voc_ckpt is None or voc_config is None or voc_stat is None: + voc_res_path = self._get_pretrained_path(voc_tag) + self.voc_res_path = voc_res_path + self.voc_config = os.path.join(voc_res_path, + pretrained_models[voc_tag]['config']) + self.voc_ckpt = os.path.join(voc_res_path, + pretrained_models[voc_tag]['ckpt']) + self.voc_stat = os.path.join( + voc_res_path, pretrained_models[voc_tag]['speech_stats']) + logger.info(voc_res_path) + logger.info(self.voc_config) + logger.info(self.voc_ckpt) + else: + self.voc_config = os.path.abspath(voc_config) + self.voc_ckpt = os.path.abspath(voc_ckpt) + self.voc_stat = os.path.abspath(voc_stat) + self.voc_res_path = os.path.dirname( + os.path.abspath(self.voc_config)) + + # Init body. 
+        with open(self.am_config) as f:
+            self.am_config = CfgNode(yaml.safe_load(f))
+        with open(self.voc_config) as f:
+            self.voc_config = CfgNode(yaml.safe_load(f))
+
+        with open(self.phones_dict, "r") as f:
+            phn_id = [line.strip().split() for line in f.readlines()]
+        self.vocab_size = len(phn_id)
+        logger.info("vocab_size: {}".format(self.vocab_size))
+
+        # frontend
+        if lang == 'zh':
+            self.frontend = Frontend(
+                phone_vocab_path=self.phones_dict,
+                tone_vocab_path=self.tones_dict)
+
+        elif lang == 'en':
+            self.frontend = English(phone_vocab_path=self.phones_dict)
+        logger.info("frontend done!")
+
+        # am infer info
+        self.am_name = am[:am.rindex('_')]
+        if self.am_name == "fastspeech2_cnndecoder":
+            self.am_inference, self.am_mu, self.am_std = self.get_model_info(
+                "am", "fastspeech2", self.am_ckpt, self.am_stat)
+        else:
+            am, am_mu, am_std = self.get_model_info("am", self.am_name,
+                                                    self.am_ckpt, self.am_stat)
+            am_normalizer = ZScore(am_mu, am_std)
+            am_inference_class = dynamic_import(self.am_name + '_inference',
+                                                model_alias)
+            self.am_inference = am_inference_class(am_normalizer, am)
+            self.am_inference.eval()
+        logger.info("acoustic model done!")
+
+        # voc infer info
+        self.voc_name = voc[:voc.rindex('_')]
+        voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name,
+                                                   self.voc_ckpt, self.voc_stat)
+        voc_normalizer = ZScore(voc_mu, voc_std)
+        voc_inference_class = dynamic_import(self.voc_name + '_inference',
+                                             model_alias)
+        self.voc_inference = voc_inference_class(voc_normalizer, voc)
+        self.voc_inference.eval()
+        logger.info("voc done!")
+
+    def get_phone(self, sentence, lang, merge_sentences, get_tone_ids):
+        tone_ids = None
+        if lang == 'zh':
+            input_ids = self.frontend.get_input_ids(
+                sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids)
+            phone_ids = input_ids["phone_ids"]
+            if get_tone_ids:
+                tone_ids = input_ids["tone_ids"]
+        elif lang == 'en':
+            input_ids = self.frontend.get_input_ids(
+                sentence, merge_sentences=merge_sentences)
+            phone_ids = input_ids["phone_ids"]
+        else:
+            logger.error("lang should be in {'zh', 'en'}!")
+            return None, None
+        return phone_ids, tone_ids
+
+    def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
+        """
+        Remove the padded regions introduced by streaming (chunked) inference,
+        keeping only the valid frames/samples of each chunk.
+        """
+        front_pad = min(chunk_id * block, pad)
+        # first chunk
+        if chunk_id == 0:
+            data = data[:block * upsample]
+        # last chunk
+        elif chunk_id == chunk_num - 1:
+            data = data[front_pad * upsample:]
+        # middle chunk
+        else:
+            data = data[front_pad * upsample:(front_pad + block) * upsample]
+
+        return data
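+    # Editor's note (hedged worked example, not part of the original patch):
+    # with block=8, pad=4, upsample=1, a middle chunk produced by get_chunks
+    # carries 4 leading pad frames, so depadding keeps chunk[4:12]; the first
+    # chunk keeps chunk[:8], and the last chunk drops only its leading pad:
+    #     depadding(chunk, chunk_num=3, chunk_id=1, block=8, pad=4, upsample=1)
+    #     # -> chunk[4:12]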
""" - am_name = am[:am.rindex('_')] - am_dataset = am[am.rindex('_') + 1:] + + am_block = self.am_block + am_pad = self.am_pad + am_upsample = 1 + voc_block = self.voc_block + voc_pad = self.voc_pad + voc_upsample = self.voc_config.n_shift + # first_flag 用于标记首包 + first_flag = 1 + get_tone_ids = False merge_sentences = False frontend_st = time.time() @@ -64,43 +373,100 @@ class TTSServerExecutor(TTSExecutor): phone_ids = input_ids["phone_ids"] else: print("lang should in {'zh', 'en'}!") - self.frontend_time = time.time() - frontend_st + frontend_et = time.time() + self.frontend_time = frontend_et - frontend_st for i in range(len(phone_ids)): - am_st = time.time() part_phone_ids = phone_ids[i] - # am - if am_name == 'speedyspeech': - part_tone_ids = tone_ids[i] - mel = self.am_inference(part_phone_ids, part_tone_ids) - # fastspeech2 + voc_chunk_id = 0 + + # fastspeech2_csmsc + if am == "fastspeech2_csmsc": + # am + mel = self.am_inference(part_phone_ids) + if first_flag == 1: + first_am_et = time.time() + self.first_am_infer = first_am_et - frontend_et + + # voc streaming + mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") + voc_chunk_num = len(mel_chunks) + voc_st = time.time() + for i, mel_chunk in enumerate(mel_chunks): + sub_wav = self.voc_inference(mel_chunk) + sub_wav = self.depadding(sub_wav, voc_chunk_num, i, + voc_block, voc_pad, voc_upsample) + if first_flag == 1: + first_voc_et = time.time() + self.first_voc_infer = first_voc_et - first_am_et + self.first_response_time = first_voc_et - frontend_st + first_flag = 0 + + yield sub_wav + + # fastspeech2_cnndecoder_csmsc + elif am == "fastspeech2_cnndecoder_csmsc": + # am + orig_hs, h_masks = self.am_inference.encoder_infer( + part_phone_ids) + + # streaming voc chunk info + mel_len = orig_hs.shape[1] + voc_chunk_num = math.ceil(mel_len / self.voc_block) + start = 0 + end = min(self.voc_block + self.voc_pad, mel_len) + + # streaming am + hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am") + am_chunk_num = len(hss) + for i, hs in enumerate(hss): + before_outs, _ = self.am_inference.decoder(hs) + after_outs = before_outs + self.am_inference.postnet( + before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + normalized_mel = after_outs[0] + sub_mel = denorm(normalized_mel, self.am_mu, self.am_std) + sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block, + am_pad, am_upsample) + + if i == 0: + mel_streaming = sub_mel + else: + mel_streaming = np.concatenate( + (mel_streaming, sub_mel), axis=0) + + # streaming voc + # 当流式AM推理的mel帧数大于流式voc推理的chunk size,开始进行流式voc 推理 + while (mel_streaming.shape[0] >= end and + voc_chunk_id < voc_chunk_num): + if first_flag == 1: + first_am_et = time.time() + self.first_am_infer = first_am_et - frontend_et + voc_chunk = mel_streaming[start:end, :] + voc_chunk = paddle.to_tensor(voc_chunk) + sub_wav = self.voc_inference(voc_chunk) + + sub_wav = self.depadding(sub_wav, voc_chunk_num, + voc_chunk_id, voc_block, + voc_pad, voc_upsample) + if first_flag == 1: + first_voc_et = time.time() + self.first_voc_infer = first_voc_et - first_am_et + self.first_response_time = first_voc_et - frontend_st + first_flag = 0 + + yield sub_wav + + voc_chunk_id += 1 + start = max(0, voc_chunk_id * voc_block - voc_pad) + end = min((voc_chunk_id + 1) * voc_block + voc_pad, + mel_len) + else: - # multi speaker - if am_dataset in {"aishell3", "vctk"}: - mel = self.am_inference( - part_phone_ids, spk_id=paddle.to_tensor(spk_id)) - else: - mel = self.am_inference(part_phone_ids) - am_et = time.time() - - # voc 
-                # voc streaming
-                voc_upsample = self.voc_config.n_shift
-                mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
-                chunk_num = len(mel_chunks)
-                voc_st = time.time()
-                for i, mel_chunk in enumerate(mel_chunks):
-                    sub_wav = self.voc_inference(mel_chunk)
-                    front_pad = min(i * voc_block, voc_pad)
-
-                    if i == 0:
-                        sub_wav = sub_wav[:voc_block * voc_upsample]
-                    elif i == chunk_num - 1:
-                        sub_wav = sub_wav[front_pad * voc_upsample:]
-                    else:
-                        sub_wav = sub_wav[front_pad * voc_upsample:(
-                            front_pad + voc_block) * voc_upsample]
-
-                    yield sub_wav
+                logger.error(
+                    "Streaming TTS only supports fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc."
+                )
+
+        self.final_response_time = time.time() - frontend_st


 class TTSEngine(BaseEngine):
@@ -113,14 +479,21 @@ class TTSEngine(BaseEngine):

     def __init__(self, name=None):
         """Initialize TTS server engine
         """
-        super(TTSEngine, self).__init__()
+        super().__init__()

     def init(self, config: dict) -> bool:
-        self.executor = TTSServerExecutor()
         self.config = config
-        assert "fastspeech2_csmsc" in config.am and (
-            config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc"
-        ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
+        assert (
+            config.am == "fastspeech2_csmsc" or
+            config.am == "fastspeech2_cnndecoder_csmsc"
+        ) and (
+            config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc"
+        ), 'Please check config, am support: fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc, voc support: hifigan_csmsc or mb_melgan_csmsc.'
+
+        assert (
+            config.voc_block > 0 and config.voc_pad > 0
+        ), "Please set correct voc_block and voc_pad, they should be greater than 0."
+
         try:
             if self.config.device:
                 self.device = self.config.device
@@ -135,6 +508,9 @@ class TTSEngine(BaseEngine):
                          (self.device))
             return False

+        self.executor = TTSServerExecutor(config.am_block, config.am_pad,
+                                          config.voc_block, config.voc_pad)
+
         try:
             self.executor._init_from_path(
                 am=self.config.am,
@@ -155,15 +531,42 @@ class TTSEngine(BaseEngine):
                          (self.device))
             return False

-        self.am_block = self.config.am_block
-        self.am_pad = self.config.am_pad
-        self.voc_block = self.config.voc_block
-        self.voc_pad = self.config.voc_pad
-
         logger.info("Initialize TTS server engine successfully on device: %s." %
                     (self.device))
+
+        # warm up
+        try:
+            self.warm_up()
+        except Exception as e:
+            logger.error("Failed to warm up on tts engine: %s" % e)
+            return False
+
         return True

+    def warm_up(self):
+        """warm up
+        """
+        if self.config.lang == 'zh':
+            sentence = "您好,欢迎使用语音合成服务。"
+        elif self.config.lang == 'en':
+            sentence = "Hello and welcome to the speech synthesis service."
+        logger.info(
+            "*******************************warm up ********************************"
+        )
+        for i in range(3):
+            for wav in self.executor.infer(
+                    text=sentence,
+                    lang=self.config.lang,
+                    am=self.config.am,
+                    spk_id=0, ):
+                logger.info(
+                    f"The first response time of warm-up {i}: {self.executor.first_response_time} s"
+                )
+                break
+        logger.info(
+            "**********************************************************************"
+        )
+
     def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
         # Convert byte to text
         if text_bese64:
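(Editor's note: a hedged sketch of how a caller might consume the streaming generator run() shown in the next hunk; engine construction and the exact parameter names are assumptions based on the code below, which yields base64-encoded int16 PCM chunks.)

    import base64
    import numpy as np

    engine = TTSEngine()
    engine.init(config)  # a parsed config with the am/voc/lang/*_block/*_pad fields validated above
    for wav_base64 in engine.run(sentence="您好,欢迎使用语音合成服务。", spk_id=0):
        pcm = np.frombuffer(base64.b64decode(wav_base64), dtype=np.int16)
        # forward pcm to the client / audio device as it arrives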
""" - lang = self.config.lang wav_list = [] for wav in self.executor.infer( text=sentence, - lang=lang, + lang=self.config.lang, am=self.config.am, - spk_id=spk_id, - am_block=self.am_block, - am_pad=self.am_pad, - voc_block=self.voc_block, - voc_pad=self.voc_pad): + spk_id=spk_id, ): + # wav type: float32, convert to pcm (base64) wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes @@ -216,5 +615,14 @@ class TTSEngine(BaseEngine): yield wav_base64 wav_all = np.concatenate(wav_list, axis=0) - logger.info("The durations of audio is: {} s".format( - len(wav_all) / self.executor.am_config.fs)) + duration = len(wav_all) / self.executor.am_config.fs + logger.info(f"sentence: {sentence}") + logger.info(f"The durations of audio is: {duration} s") + logger.info( + f"first response time: {self.executor.first_response_time} s") + logger.info( + f"final response time: {self.executor.final_response_time} s") + logger.info(f"RTF: {self.executor.final_response_time / duration}") + logger.info( + f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s," + ) diff --git a/paddlespeech/server/tests/__init__.py b/paddlespeech/server/tests/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/tests/asr/__init__.py b/paddlespeech/server/tests/asr/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/tests/asr/offline/__init__.py b/paddlespeech/server/tests/asr/offline/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/offline/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/tests/asr/online/__init__.py b/paddlespeech/server/tests/asr/online/__init__.py
new file mode 100644
index 00000000..97043fd7
--- /dev/null
+++ b/paddlespeech/server/tests/asr/online/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py
index 01b19405..49cbd703 100644
--- a/paddlespeech/server/tests/asr/online/websocket_client.py
+++ b/paddlespeech/server/tests/asr/online/websocket_client.py
@@ -34,10 +34,9 @@ class ASRAudioHandler:
     def read_wave(self, wavfile_path: str):
         samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
         x_len = len(samples)
-        # chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
-        chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
-        if x_len % chunk_size != 0:
+        chunk_size = 85 * 16  #85ms, sample_rate = 16kHz
+        if x_len % chunk_size != 0:
             padding_len_x = chunk_size - x_len % chunk_size
         else:
             padding_len_x = 0
@@ -48,7 +47,6 @@ class ASRAudioHandler:
         assert (x_len + padding_len_x) % chunk_size == 0
         num_chunk = (x_len + padding_len_x) / chunk_size
         num_chunk = int(num_chunk)
-
         for i in range(0, num_chunk):
             start = i * chunk_size
             end = start + chunk_size
@@ -57,7 +55,11 @@ class ASRAudioHandler:

     async def run(self, wavfile_path: str):
         logging.info("send a message to the server")
+        # self.read_wave()
+        # send the websocket handshake protocol
         async with websockets.connect(self.url) as ws:
+            # the server has already received the handshake protocol
+            # the client starts to send the command
             audio_info = json.dumps(
                 {
                     "name": "test.wav",
@@ -78,7 +80,6 @@ class ASRAudioHandler:
                 msg = json.loads(msg)
                 logging.info("receive msg={}".format(msg))

-            result = msg
             # finished
             audio_info = json.dumps(
                 {
@@ -91,10 +92,12 @@ class ASRAudioHandler:
                 separators=(',', ': '))
             await ws.send(audio_info)
             msg = await ws.recv()
+
+            # decode the bytes to str
             msg = json.loads(msg)
-            logging.info("receive msg={}".format(msg))
-
-            return result
+            logging.info("final receive msg={}".format(msg))
+            result = msg
+            return result


 def main(args):
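(Editor's note: a hedged, self-contained sketch of the message flow ASRAudioHandler implements above — a JSON "start" signal, raw PCM chunks, then a JSON "end" signal. The URL and any fields beyond "name"/"signal" are illustrative, not taken from the patch.)

    import asyncio
    import json
    import websockets

    async def recognize(url: str, chunks):
        async with websockets.connect(url) as ws:
            await ws.send(json.dumps({"name": "test.wav", "signal": "start"}))
            print(json.loads(await ws.recv()))      # e.g. {"status": "ok", "signal": "server_ready"}
            for chunk in chunks:                    # bytes, e.g. 85 ms of 16 kHz int16 PCM
                await ws.send(chunk)
                print(json.loads(await ws.recv()))  # partial asr_results
            await ws.send(json.dumps({"name": "test.wav", "signal": "end"}))
            return json.loads(await ws.recv())      # final result

    # asyncio.run(recognize("ws://127.0.0.1:8090/ws/asr", chunks))  # port is illustrative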
""" + audio = self.remained_audio + audio self.remained_audio = b'' offset = 0 timestamp = 0.0 - while offset + self.window_bytes <= len(audio): yield Frame(audio[offset:offset + self.window_bytes], timestamp, self.window_sec) diff --git a/paddlespeech/server/utils/util.py b/paddlespeech/server/utils/util.py index 0fe70849..72ee0060 100644 --- a/paddlespeech/server/utils/util.py +++ b/paddlespeech/server/utils/util.py @@ -52,6 +52,10 @@ def get_chunks(data, block_size, pad_size, step): Returns: list: chunks list """ + + if block_size == -1: + return [data] + if step == "am": data_len = data.shape[1] elif step == "voc": diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 03a49b48..a865703d 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -13,12 +13,12 @@ # limitations under the License. import json -import numpy as np from fastapi import APIRouter from fastapi import WebSocket from fastapi import WebSocketDisconnect from starlette.websockets import WebSocketState as WebSocketState +from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.utils.buffer import ChunkBuffer from paddlespeech.server.utils.vad import VADAudio @@ -28,26 +28,29 @@ router = APIRouter() @router.websocket('/ws/asr') async def websocket_endpoint(websocket: WebSocket): - await websocket.accept() engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] + connection_handler = None # init buffer + # each websocekt connection has its own chunk buffer chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer = ChunkBuffer( - window_n=7, - shift_n=4, - window_ms=20, - shift_ms=10, - sample_rate=chunk_buffer_conf['sample_rate'], - sample_width=chunk_buffer_conf['sample_width']) + window_n=chunk_buffer_conf.window_n, + shift_n=chunk_buffer_conf.shift_n, + window_ms=chunk_buffer_conf.window_ms, + shift_ms=chunk_buffer_conf.shift_ms, + sample_rate=chunk_buffer_conf.sample_rate, + sample_width=chunk_buffer_conf.sample_width) + # init vad - vad_conf = asr_engine.config.vad_conf - vad = VADAudio( - aggressiveness=vad_conf['aggressiveness'], - rate=vad_conf['sample_rate'], - frame_duration_ms=vad_conf['frame_duration_ms']) + vad_conf = asr_engine.config.get('vad_conf', None) + if vad_conf: + vad = VADAudio( + aggressiveness=vad_conf['aggressiveness'], + rate=vad_conf['sample_rate'], + frame_duration_ms=vad_conf['frame_duration_ms']) try: while True: @@ -64,13 +67,21 @@ async def websocket_endpoint(websocket: WebSocket): if message['signal'] == 'start': resp = {"status": "ok", "signal": "server_ready"} # do something at begining here + # create the instance to process the audio + connection_handler = PaddleASRConnectionHanddler(asr_engine) await websocket.send_json(resp) elif message['signal'] == 'end': - engine_pool = get_engine_pool() - asr_engine = engine_pool['asr'] # reset single engine for an new connection - asr_engine.reset() - resp = {"status": "ok", "signal": "finished"} + connection_handler.decode(is_finished=True) + connection_handler.rescoring() + asr_results = connection_handler.get_result() + connection_handler.reset() + + resp = { + "status": "ok", + "signal": "finished", + 'asr_results': asr_results + } await websocket.send_json(resp) break else: @@ -79,21 +90,11 @@ async def websocket_endpoint(websocket: WebSocket): elif "bytes" in message: message = message["bytes"] - engine_pool = 
-                engine_pool = get_engine_pool()
-                asr_engine = engine_pool['asr']
-                asr_results = ""
-                frames = chunk_buffer.frame_generator(message)
-                for frame in frames:
-                    samples = np.frombuffer(frame.bytes, dtype=np.int16)
-                    sample_rate = asr_engine.config.sample_rate
-                    x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
-                                                                  sample_rate)
-                    asr_engine.run(x_chunk, x_chunk_lens)
-                    asr_results = asr_engine.postprocess()
+                connection_handler.extract_feat(message)
+                connection_handler.decode(is_finished=False)
+                asr_results = connection_handler.get_result()

-                asr_results = asr_engine.postprocess()
                 resp = {'asr_results': asr_results}
-
                 await websocket.send_json(resp)
     except WebSocketDisconnect:
         pass
diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt
index f1330d1d..98d9e637 100644
--- a/speechx/CMakeLists.txt
+++ b/speechx/CMakeLists.txt
@@ -63,7 +63,8 @@ include(libsndfile)
 # include(boost) # not work
 set(boost_SOURCE_DIR ${fc_patch}/boost-src)
 set(BOOST_ROOT ${boost_SOURCE_DIR})
-# #find_package(boost REQUIRED PATHS ${BOOST_ROOT})
+include_directories(${boost_SOURCE_DIR})
+link_directories(${boost_SOURCE_DIR}/stage/lib)

 # Eigen
 include(eigen)
@@ -141,4 +142,4 @@ set(DEPS ${DEPS}
 set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
 add_subdirectory(speechx)

-add_subdirectory(examples)
\ No newline at end of file
+add_subdirectory(examples)
diff --git a/speechx/examples/ds2_ol/CMakeLists.txt b/speechx/examples/ds2_ol/CMakeLists.txt
index 89cbd0ef..08c19484 100644
--- a/speechx/examples/ds2_ol/CMakeLists.txt
+++ b/speechx/examples/ds2_ol/CMakeLists.txt
@@ -2,4 +2,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

 add_subdirectory(feat)
 add_subdirectory(nnet)
-add_subdirectory(decoder)
\ No newline at end of file
+add_subdirectory(decoder)
+add_subdirectory(websocket)
diff --git a/speechx/examples/ds2_ol/aishell/path.sh b/speechx/examples/ds2_ol/aishell/path.sh
index 0a300f36..520129ea 100644
--- a/speechx/examples/ds2_ol/aishell/path.sh
+++ b/speechx/examples/ds2_ol/aishell/path.sh
@@ -1,6 +1,6 @@
 # This contains the locations of binarys build required for running the examples.

-SPEECHX_ROOT=$PWD/../../../
+SPEECHX_ROOT=$PWD/../../..
 SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples

 SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin

 export LC_AL=C

-SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
-export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
\ No newline at end of file
+SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat:$SPEECHX_EXAMPLES/ds2_ol/websocket
+export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
diff --git a/speechx/examples/ds2_ol/aishell/run.sh b/speechx/examples/ds2_ol/aishell/run.sh
index 6a59ca9b..49fa5bc3 100755
--- a/speechx/examples/ds2_ol/aishell/run.sh
+++ b/speechx/examples/ds2_ol/aishell/run.sh
@@ -86,7 +86,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
     ctc-prefix-beam-search-decoder-ol \
         --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
         --model_path=$model_dir/avg_1.jit.pdmodel \
-        --param_path=$model_dir/avg_1.jit.pdiparams \
+        --params_path=$model_dir/avg_1.jit.pdiparams \
         --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
         --dict_file=$vocb_dir/vocab.txt \
         --result_wspecifier=ark,t:$data/split${nj}/JOB/result
@@ -101,7 +101,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ];then
     ctc-prefix-beam-search-decoder-ol \
         --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
         --model_path=$model_dir/avg_1.jit.pdmodel \
-        --param_path=$model_dir/avg_1.jit.pdiparams \
+        --params_path=$model_dir/avg_1.jit.pdiparams \
         --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
         --dict_file=$vocb_dir/vocab.txt \
         --lm_path=$lm \
@@ -128,7 +128,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
     wfst-decoder-ol \
         --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
         --model_path=$model_dir/avg_1.jit.pdmodel \
-        --param_path=$model_dir/avg_1.jit.pdiparams \
+        --params_path=$model_dir/avg_1.jit.pdiparams \
         --word_symbol_table=$graph_dir/words.txt \
         --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
         --graph_path=$graph_dir/TLG.fst --max_active=7500 \
@@ -137,4 +137,4 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
     cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg
     utils/compute-wer.py --char=1 --v=1 $exp/${label_file}_tlg $text > $exp/${wer}_tlg
-fi
\ No newline at end of file
+fi
diff --git a/speechx/examples/ds2_ol/aishell/websocket_client.sh b/speechx/examples/ds2_ol/aishell/websocket_client.sh
new file mode 100644
index 00000000..3c6b4e91
--- /dev/null
+++ b/speechx/examples/ds2_ol/aishell/websocket_client.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+set +x
+set -e
+
+. path.sh
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+    pushd ${SPEECHX_ROOT}
+    bash build.sh
+    popd
+fi
+
+# input
+mkdir -p data
+data=$PWD/data
+ckpt_dir=$data/model
+model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
+vocb_dir=$ckpt_dir/data/lang_char
+# output
+aishell_wav_scp=aishell_test.scp
+if [ ! -d $data/test ]; then
+    pushd $data
+    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
+    unzip aishell_test.zip
+    popd
+
+    realpath $data/test/*/*.wav > $data/wavlist
+    awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
+    paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
+fi
+
+export GLOG_logtostderr=1
+
+# websocket client
+websocket_client_main \
+    --wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.36
diff --git a/speechx/examples/ds2_ol/aishell/websocket_server.sh b/speechx/examples/ds2_ol/aishell/websocket_server.sh
new file mode 100644
index 00000000..ea619d54
--- /dev/null
+++ b/speechx/examples/ds2_ol/aishell/websocket_server.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+set +x
+set -e
+
+. path.sh
+
+
+# 1. compile
+if [ ! -d ${SPEECHX_EXAMPLES} ]; then
+    pushd ${SPEECHX_ROOT}
+    bash build.sh
+    popd
+fi
+
+# input
+mkdir -p data
+data=$PWD/data
+ckpt_dir=$data/model
+model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
+vocb_dir=$ckpt_dir/data/lang_char/
+
+# output
+aishell_wav_scp=aishell_test.scp
+if [ ! -d $data/test ]; then
+    pushd $data
+    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
+    unzip aishell_test.zip
+    popd
+
+    realpath $data/test/*/*.wav > $data/wavlist
+    awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
+    paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
+fi
+
+
+if [ ! -d $ckpt_dir ]; then
+    mkdir -p $ckpt_dir
+    wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
+    tar xzfv $ckpt_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir
+fi
+
+
+export GLOG_logtostderr=1
+
+# 3. gen cmvn
+cmvn=$PWD/cmvn.ark
+cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
+
+text=$data/test/text
+graph_dir=./aishell_graph
+if [ ! -d $graph_dir ]; then
+    wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
+    unzip aishell_graph.zip
+fi
+
+# 5. test websocket server
+websocket_server_main \
+    --cmvn_file=$cmvn \
+    --model_path=$model_dir/avg_1.jit.pdmodel \
+    --streaming_chunk=0.1 \
+    --convert2PCM32=true \
+    --params_path=$model_dir/avg_1.jit.pdiparams \
+    --word_symbol_table=$graph_dir/words.txt \
+    --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
+    --graph_path=$graph_dir/TLG.fst --max_active=7500 \
+    --acoustic_scale=1.2
diff --git a/speechx/examples/ds2_ol/decoder/CMakeLists.txt b/speechx/examples/ds2_ol/decoder/CMakeLists.txt
index 6139ebfa..62dd6862 100644
--- a/speechx/examples/ds2_ol/decoder/CMakeLists.txt
+++ b/speechx/examples/ds2_ol/decoder/CMakeLists.txt
@@ -17,3 +17,6 @@ add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
 target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
 target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})

+add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc)
+target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})
diff --git a/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc b/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc
index 49d64b69..e145f6ee 100644
--- a/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc
+++ b/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc
@@ -34,12 +34,10 @@ DEFINE_int32(receptive_field_length,
 DEFINE_int32(downsampling_rate,
              4,
              "two CNN(kernel=5) module downsampling rate.");
-DEFINE_string(
-    model_input_names,
-    "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
-    "model input names");
 DEFINE_string(model_output_names,
-              "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
+              "save_infer_model/scale_0.tmp_1,save_infer_model/"
+              "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
+              "scale_3.tmp_1",
               "model output names");
 DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");

@@ -58,12 +56,11 @@ int main(int argc, char* argv[]) {
     kaldi::SequentialBaseFloatMatrixReader feature_reader(
         FLAGS_feature_rspecifier);
     kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
-
-    std::string model_graph = FLAGS_model_path;
+    std::string model_path = FLAGS_model_path;
     std::string model_params = FLAGS_param_path;
     std::string dict_file = FLAGS_dict_file;
     std::string lm_path = FLAGS_lm_path;
-    LOG(INFO) << "model path: " << model_graph;
+    LOG(INFO) << "model path: " << model_path;
     LOG(INFO) << "model param: " << model_params;
     LOG(INFO) << "dict path: " << dict_file;
     LOG(INFO) << "lm path: " << lm_path;
@@ -76,10 +73,9 @@ int main(int argc, char* argv[]) {
     ppspeech::CTCBeamSearch decoder(opts);

     ppspeech::ModelOptions model_opts;
-    model_opts.model_path = model_graph;
+    model_opts.model_path = model_path;
     model_opts.params_path = model_params;
     model_opts.cache_shape = FLAGS_model_cache_names;
-    model_opts.input_names = FLAGS_model_input_names;
     model_opts.output_names = FLAGS_model_output_names;
     std::shared_ptr<ppspeech::PaddleNnet> nnet(
         new ppspeech::PaddleNnet(model_opts));
@@ -125,7 +121,6 @@ int main(int argc, char* argv[]) {
         if (feature_chunk_size < receptive_field_length) break;

         int32 start = chunk_idx * chunk_stride;
-        int32 end = start + chunk_size;

        for (int row_id = 0; row_id < chunk_size; ++row_id) {
            kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
diff --git a/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc b/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc
new file mode 100644
index 00000000..198a8ec2
--- /dev/null
+++ b/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc
@@ -0,0 +1,85 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "decoder/recognizer.h"
+#include "decoder/param.h"
+#include "kaldi/feat/wave-reader.h"
+#include "kaldi/util/table-types.h"
+
+DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
+DEFINE_string(result_wspecifier, "", "test result wspecifier");
+
+int main(int argc, char* argv[]) {
+    gflags::ParseCommandLineFlags(&argc, &argv, false);
+    google::InitGoogleLogging(argv[0]);
+
+    ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
+    ppspeech::Recognizer recognizer(resource);
+
+    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
+        FLAGS_wav_rspecifier);
+    kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
+    int sample_rate = 16000;
+    float streaming_chunk = FLAGS_streaming_chunk;
+    int chunk_sample_size = streaming_chunk * sample_rate;
+    LOG(INFO) << "sr: " << sample_rate;
+    LOG(INFO) << "chunk size (s): " << streaming_chunk;
+    LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
+
+    int32 num_done = 0, num_err = 0;
+
+    for (; !wav_reader.Done(); wav_reader.Next()) {
+        std::string utt = wav_reader.Key();
+        const kaldi::WaveData& wave_data = wav_reader.Value();
+
+        int32 this_channel = 0;
+        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
+                                                    this_channel);
+        int tot_samples = waveform.Dim();
+        LOG(INFO) << "wav len (sample): " << tot_samples;
+
+        int sample_offset = 0;
+        std::vector<kaldi::Vector<kaldi::BaseFloat>> feats;
+        int feature_rows = 0;
+        while (sample_offset < tot_samples) {
+            int cur_chunk_size =
+                std::min(chunk_sample_size, tot_samples - sample_offset);
+
+            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
+            for (int i = 0; i < cur_chunk_size; ++i) {
+                wav_chunk(i) = waveform(sample_offset + i);
+            }
+
+            recognizer.Accept(wav_chunk);
+            if (cur_chunk_size < chunk_sample_size) {
+                recognizer.SetFinished();
+            }
+            recognizer.Decode();
+
+            sample_offset += cur_chunk_size;
+        }
+        std::string result;
+        result = recognizer.GetFinalResult();
+        recognizer.Reset();
+        if (result.empty()) {
+            // the TokenWriter cannot write an empty string.
+            ++num_err;
+            KALDI_LOG << " the result of " << utt << " is empty";
+            continue;
+        }
+        KALDI_LOG << " the result of " << utt << " is " << result;
+        result_writer.Write(utt, result);
+        ++num_done;
+    }
+}
\ No newline at end of file
diff --git a/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc b/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc
index b8385664..0a9cfb06 100644
--- a/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc
+++ b/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc
@@ -73,9 +73,9 @@ int main(int argc, char* argv[]) {
         LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
         LOG(INFO) << "Binary: " << FLAGS_binary;
     } catch (simdjson::simdjson_error& err) {
-        LOG(ERR) << err.what();
+        LOG(ERROR) << err.what();
     }

     return 0;
-}
\ No newline at end of file
+}
diff --git a/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc b/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc
index 27ca6f9f..0d10bd30 100644
--- a/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc
+++ b/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc
@@ -32,7 +32,6 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
 DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
 DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");

-
 int main(int argc, char* argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, false);
     google::InitGoogleLogging(argv[0]);
@@ -66,7 +65,8 @@ int main(int argc, char* argv[]) {
     std::unique_ptr<ppspeech::FrontendInterface> cmvn(
         new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram)));

-    ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
+    ppspeech::FeatureCacheOptions feat_cache_opts;
+    ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
     LOG(INFO) << "feat dim: " << feature_cache.Dim();

     int sample_rate = 16000;
diff --git a/speechx/examples/ds2_ol/websocket/CMakeLists.txt b/speechx/examples/ds2_ol/websocket/CMakeLists.txt
new file mode 100644
index 00000000..754b528e
--- /dev/null
+++ b/speechx/examples/ds2_ol/websocket/CMakeLists.txt
@@ -0,0 +1,10 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc)
+target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})
+
+add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
+target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})
+
diff --git a/speechx/examples/ds2_ol/websocket/websocket_client_main.cc b/speechx/examples/ds2_ol/websocket/websocket_client_main.cc
new file mode 100644
index 00000000..68ea898a
--- /dev/null
+++ b/speechx/examples/ds2_ol/websocket/websocket_client_main.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "websocket/websocket_client.h"
+#include "kaldi/feat/wave-reader.h"
+#include "kaldi/util/kaldi-io.h"
+#include "kaldi/util/table-types.h"
+
+DEFINE_string(host, "127.0.0.1", "host of websocket server");
+DEFINE_int32(port, 201314, "port of websocket server");
+DEFINE_string(wav_rspecifier, "", "test wav scp path");
+DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
+
+using kaldi::int16;
+int main(int argc, char* argv[]) {
+    gflags::ParseCommandLineFlags(&argc, &argv, false);
+    google::InitGoogleLogging(argv[0]);
+    ppspeech::WebSocketClient client(FLAGS_host, FLAGS_port);
+
+    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
+        FLAGS_wav_rspecifier);
+
+    const int sample_rate = 16000;
+    const float streaming_chunk = FLAGS_streaming_chunk;
+    const int chunk_sample_size = streaming_chunk * sample_rate;
+
+    for (; !wav_reader.Done(); wav_reader.Next()) {
+        client.SendStartSignal();
+        std::string utt = wav_reader.Key();
+        const kaldi::WaveData& wave_data = wav_reader.Value();
+        CHECK_EQ(wave_data.SampFreq(), sample_rate);
+
+        int32 this_channel = 0;
+        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
+                                                    this_channel);
+        const int tot_samples = waveform.Dim();
+        int sample_offset = 0;
+
+        while (sample_offset < tot_samples) {
+            int cur_chunk_size =
+                std::min(chunk_sample_size, tot_samples - sample_offset);
+
+            std::vector<int16> wav_chunk(cur_chunk_size);
+            for (int i = 0; i < cur_chunk_size; ++i) {
+                wav_chunk[i] = static_cast<int16>(waveform(sample_offset + i));
+            }
+            client.SendBinaryData(wav_chunk.data(),
+                                  wav_chunk.size() * sizeof(int16));
+
+
+            sample_offset += cur_chunk_size;
+            LOG(INFO) << "Send " << cur_chunk_size << " samples";
+            std::this_thread::sleep_for(
+                std::chrono::milliseconds(static_cast<int>(1 * 1000)));
+
+            if (cur_chunk_size < chunk_sample_size) {
+                client.SendEndSignal();
+            }
+        }
+
+        while (!client.Done()) {
+        }
+        std::string result = client.GetResult();
+        LOG(INFO) << "utt: " << utt << " " << result;
+
+
+        client.Join();
+        return 0;
+    }
+    return 0;
+}
diff --git a/speechx/examples/ds2_ol/websocket/websocket_server_main.cc b/speechx/examples/ds2_ol/websocket/websocket_server_main.cc
new file mode 100644
index 00000000..43cbd6bb
--- /dev/null
+++ b/speechx/examples/ds2_ol/websocket/websocket_server_main.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "websocket/websocket_server.h"
+#include "decoder/param.h"
+
+DEFINE_int32(port, 201314, "websocket listening port");
+
+int main(int argc, char *argv[]) {
+    gflags::ParseCommandLineFlags(&argc, &argv, false);
+    google::InitGoogleLogging(argv[0]);
+
+    ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
+
+    ppspeech::WebSocketServer server(FLAGS_port, resource);
+    LOG(INFO) << "Listening at port " << FLAGS_port;
+    server.Start();
+    return 0;
+}
diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt
index 225abee7..b4da095d 100644
--- a/speechx/speechx/CMakeLists.txt
+++ b/speechx/speechx/CMakeLists.txt
@@ -30,4 +30,10 @@ include_directories(
 ${CMAKE_CURRENT_SOURCE_DIR}
 ${CMAKE_CURRENT_SOURCE_DIR}/decoder
 )
-add_subdirectory(decoder)
\ No newline at end of file
+add_subdirectory(decoder)
+
+include_directories(
+${CMAKE_CURRENT_SOURCE_DIR}
+${CMAKE_CURRENT_SOURCE_DIR}/websocket
+)
+add_subdirectory(websocket)
diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h
index 7502bc5e..a9303cbb 100644
--- a/speechx/speechx/base/common.h
+++ b/speechx/speechx/base/common.h
@@ -28,8 +28,10 @@
 #include
 #include
 #include
+#include
 #include
 #include
+#include
 #include

 #include "base/basic_types.h"
diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt
index ee0863fd..06bf4020 100644
--- a/speechx/speechx/decoder/CMakeLists.txt
+++ b/speechx/speechx/decoder/CMakeLists.txt
@@ -7,5 +7,6 @@ add_library(decoder STATIC
     ctc_decoders/path_trie.cpp
     ctc_decoders/scorer.cpp
     ctc_tlg_decoder.cc
+    recognizer.cc
 )
-target_link_libraries(decoder PUBLIC kenlm utils fst)
+target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder)
diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc
index 5365e709..7b720e7b 100644
--- a/speechx/speechx/decoder/ctc_tlg_decoder.cc
+++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc
@@ -33,7 +33,6 @@ void TLGDecoder::InitDecoder() {
 void TLGDecoder::AdvanceDecode(
     const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
     while (!decodable->IsLastFrame(frame_decoded_size_)) {
-        LOG(INFO) << "num frame decode: " << frame_decoded_size_;
         AdvanceDecoding(decodable.get());
     }
 }
@@ -63,4 +62,4 @@ std::string TLGDecoder::GetFinalBestPath() {
     }
     return words;
 }
-}
\ No newline at end of file
+}
diff --git a/speechx/speechx/decoder/param.h b/speechx/speechx/decoder/param.h
new file mode 100644
index 00000000..cd50ef53
--- /dev/null
+++ b/speechx/speechx/decoder/param.h
@@ -0,0 +1,94 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "base/common.h"
+
+#include "decoder/ctc_beam_search_decoder.h"
+#include "decoder/ctc_tlg_decoder.h"
+#include "frontend/audio/feature_pipeline.h"
+
+DEFINE_string(cmvn_file, "", "read cmvn");
+DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
+DEFINE_bool(convert2PCM32, true, "audio convert to pcm32");
+DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
+DEFINE_string(params_path, "avg_1.jit.pdiparams", "paddle nnet model param");
+DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
+DEFINE_string(graph_path, "TLG", "decoder graph");
+DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
+DEFINE_int32(max_active, 7500, "max active");
+DEFINE_double(beam, 15.0, "decoder beam");
+DEFINE_double(lattice_beam, 7.5, "decoder lattice beam");
+DEFINE_int32(receptive_field_length,
+             7,
+             "receptive field of two CNN(kernel=5) downsampling module.");
+DEFINE_int32(downsampling_rate,
+             4,
+             "two CNN(kernel=5) module downsampling rate.");
+DEFINE_string(model_output_names,
+              "save_infer_model/scale_0.tmp_1,save_infer_model/"
+              "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
+              "scale_3.tmp_1",
+              "model output names");
+DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
+
+namespace ppspeech {
+// todo refactor later
+FeaturePipelineOptions InitFeaturePipelineOptions() {
+    FeaturePipelineOptions opts;
+    opts.cmvn_file = FLAGS_cmvn_file;
+    opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk;
+    opts.convert2PCM32 = FLAGS_convert2PCM32;
+    kaldi::FrameExtractionOptions frame_opts;
+    frame_opts.frame_length_ms = 20;
+    frame_opts.frame_shift_ms = 10;
+    frame_opts.remove_dc_offset = false;
+    frame_opts.window_type = "hanning";
+    frame_opts.preemph_coeff = 0.0;
+    frame_opts.dither = 0.0;
+    opts.linear_spectrogram_opts.frame_opts = frame_opts;
+    opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length;
+    opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate;
+    return opts;
+}
+
+ModelOptions InitModelOptions() {
+    ModelOptions model_opts;
+    model_opts.model_path = FLAGS_model_path;
+    model_opts.params_path = FLAGS_params_path;
+    model_opts.cache_shape = FLAGS_model_cache_names;
+    model_opts.output_names = FLAGS_model_output_names;
+    return model_opts;
+}
+
+TLGDecoderOptions InitDecoderOptions() {
+    TLGDecoderOptions decoder_opts;
+    decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
+    decoder_opts.fst_path = FLAGS_graph_path;
+    decoder_opts.opts.max_active = FLAGS_max_active;
+    decoder_opts.opts.beam = FLAGS_beam;
+    decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
+    return decoder_opts;
+}
+
+RecognizerResource InitRecognizerResoure() {
+    RecognizerResource resource;
+    resource.acoustic_scale = FLAGS_acoustic_scale;
+    resource.feature_pipeline_opts = InitFeaturePipelineOptions();
+    resource.model_opts = InitModelOptions();
+    resource.tlg_opts = InitDecoderOptions();
+    return resource;
+}
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/decoder/recognizer.cc b/speechx/speechx/decoder/recognizer.cc
new file mode 100644
index 00000000..2c90ada9
--- /dev/null
+++ b/speechx/speechx/decoder/recognizer.cc
@@ -0,0 +1,60 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "decoder/recognizer.h"
+
+namespace ppspeech {
+
+using kaldi::Vector;
+using kaldi::VectorBase;
+using kaldi::BaseFloat;
+using std::vector;
+using kaldi::SubVector;
+using std::unique_ptr;
+
+Recognizer::Recognizer(const RecognizerResource& resource) {
+    // resource_ = resource;
+    const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
+    feature_pipeline_.reset(new FeaturePipeline(feature_opts));
+    std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(resource.model_opts));
+    BaseFloat ac_scale = resource.acoustic_scale;
+    decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale));
+    decoder_.reset(new TLGDecoder(resource.tlg_opts));
+    input_finished_ = false;
+}
+
+void Recognizer::Accept(const Vector<BaseFloat>& waves) {
+    feature_pipeline_->Accept(waves);
+}
+
+void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); }
+
+std::string Recognizer::GetFinalResult() {
+    return decoder_->GetFinalBestPath();
+}
+
+void Recognizer::SetFinished() {
+    feature_pipeline_->SetFinished();
+    input_finished_ = true;
+}
+
+bool Recognizer::IsFinished() { return input_finished_; }
+
+void Recognizer::Reset() {
+    feature_pipeline_->Reset();
+    decodable_->Reset();
+    decoder_->Reset();
+}
+
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/decoder/recognizer.h
new file mode 100644
index 00000000..9a7e7d11
--- /dev/null
+++ b/speechx/speechx/decoder/recognizer.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// todo refactor later (SGoat)
+
+#pragma once
+
+#include "decoder/ctc_beam_search_decoder.h"
+#include "decoder/ctc_tlg_decoder.h"
+#include "frontend/audio/feature_pipeline.h"
+#include "nnet/decodable.h"
+#include "nnet/paddle_nnet.h"
+
+namespace ppspeech {
+
+struct RecognizerResource {
+    FeaturePipelineOptions feature_pipeline_opts;
+    ModelOptions model_opts;
+    TLGDecoderOptions tlg_opts;
+    // CTCBeamSearchOptions beam_search_opts;
+    kaldi::BaseFloat acoustic_scale;
+    RecognizerResource()
+        : acoustic_scale(1.0),
+          feature_pipeline_opts(),
+          model_opts(),
+          tlg_opts() {}
+};
+
+class Recognizer {
+  public:
+    explicit Recognizer(const RecognizerResource& resource);
+    void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
+    void Decode();
+    std::string GetFinalResult();
+    void SetFinished();
+    bool IsFinished();
+    void Reset();
+
+  private:
+    // std::shared_ptr<RecognizerResource> resource_;
+    // RecognizerResource resource_;
+    std::shared_ptr<FeaturePipeline> feature_pipeline_;
+    std::shared_ptr<Decodable> decodable_;
+    std::unique_ptr<TLGDecoder> decoder_;
+    bool input_finished_;
+};
+
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/frontend/audio/CMakeLists.txt b/speechx/speechx/frontend/audio/CMakeLists.txt
index 35243b6e..2d20edf7 100644
--- a/speechx/speechx/frontend/audio/CMakeLists.txt
+++ b/speechx/speechx/frontend/audio/CMakeLists.txt
@@ -6,6 +6,7 @@ add_library(frontend STATIC
     linear_spectrogram.cc
     audio_cache.cc
     feature_cache.cc
+    feature_pipeline.cc
 )

-target_link_libraries(frontend PUBLIC kaldi-matrix)
\ No newline at end of file
+target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common)
diff --git a/speechx/speechx/frontend/audio/audio_cache.cc b/speechx/speechx/frontend/audio/audio_cache.cc
index 50aca4fb..e8af6668 100644
--- a/speechx/speechx/frontend/audio/audio_cache.cc
+++ b/speechx/speechx/frontend/audio/audio_cache.cc
@@ -41,7 +41,7 @@ void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
         ready_feed_condition_.wait(lock);
     }
     for (size_t idx = 0; idx < waves.Dim(); ++idx) {
-        int32 buffer_idx = (idx + offset_) % ring_buffer_.size();
+        int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size();
         ring_buffer_[buffer_idx] = waves(idx);
         if (convert2PCM32_)
             ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx));
diff --git a/speechx/speechx/frontend/audio/audio_cache.h b/speechx/speechx/frontend/audio/audio_cache.h
index adef1239..a681ef09 100644
--- a/speechx/speechx/frontend/audio/audio_cache.h
+++ b/speechx/speechx/frontend/audio/audio_cache.h
@@ -24,7 +24,7 @@ namespace ppspeech {
 class AudioCache : public FrontendInterface {
   public:
     explicit AudioCache(int buffer_size = 1000 * kint16max,
-                        bool convert2PCM32 = false);
+                        bool convert2PCM32 = true);

     virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);
diff --git a/speechx/speechx/frontend/audio/feature_cache.cc b/speechx/speechx/frontend/audio/feature_cache.cc
index 3f7f6502..b5768460 100644
--- a/speechx/speechx/frontend/audio/feature_cache.cc
+++ b/speechx/speechx/frontend/audio/feature_cache.cc
@@ -23,10 +23,13 @@ using std::vector;
 using kaldi::SubVector;
 using std::unique_ptr;

-FeatureCache::FeatureCache(int max_size,
+FeatureCache::FeatureCache(FeatureCacheOptions opts,
                            unique_ptr<FrontendInterface> base_extractor) {
-    max_size_ = max_size;
+    max_size_ = opts.max_size;
+    frame_chunk_stride_ = opts.frame_chunk_stride;
+    frame_chunk_size_ = opts.frame_chunk_size;
     base_extractor_ = std::move(base_extractor);
+    dim_ = base_extractor_->Dim();
 }

 void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
@@ -44,13 +47,14 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
     std::unique_lock<std::mutex> lock(mutex_);
     while (cache_.empty() && base_extractor_->IsFinished() == false) {
-        ready_read_condition_.wait(lock);
-        BaseFloat elapsed = timer.Elapsed() * 1000;
-        // todo replace 1.0 with timeout_
-        if (elapsed > 1.0) {
+        // todo refactor: wait
+        // ready_read_condition_.wait(lock);
+        int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
+        // todo replace 1 with timeout_, 1 ms
+        if (elapsed > 1) {
             return false;
         }
-        usleep(1000);  // sleep 1 ms
+        usleep(100);  // sleep 0.1 ms
     }
     if (cache_.empty()) return false;
     feats->Resize(cache_.front().Dim());
@@ -63,25 +67,41 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {

 // read all data from base_feature_extractor_ into cache_
 bool FeatureCache::Compute() {
     // compute and feed
-    Vector<BaseFloat> feature_chunk;
-    bool result = base_extractor_->Read(&feature_chunk);
+    Vector<BaseFloat> feature;
+    bool result = base_extractor_->Read(&feature);
+    if (result == false || feature.Dim() == 0) return false;
+    int32 joint_len = feature.Dim() + remained_feature_.Dim();
+    int32 num_chunk =
+        ((joint_len / dim_) - frame_chunk_size_) / frame_chunk_stride_ + 1;

-    std::unique_lock<std::mutex> lock(mutex_);
-    while (cache_.size() >= max_size_) {
-        ready_feed_condition_.wait(lock);
-    }
+    Vector<BaseFloat> joint_feature(joint_len);
+    joint_feature.Range(0, remained_feature_.Dim())
+        .CopyFromVec(remained_feature_);
+    joint_feature.Range(remained_feature_.Dim(), feature.Dim())
+        .CopyFromVec(feature);

-    // feed cache
-    if (feature_chunk.Dim() != 0) {
+    for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
+        int32 start = chunk_idx * frame_chunk_stride_ * dim_;
+        Vector<BaseFloat> feature_chunk(frame_chunk_size_ * dim_);
+        SubVector<BaseFloat> tmp(joint_feature.Data() + start,
+                                 frame_chunk_size_ * dim_);
+        feature_chunk.CopyFromVec(tmp);
+
+        std::unique_lock<std::mutex> lock(mutex_);
+        while (cache_.size() >= max_size_) {
+            ready_feed_condition_.wait(lock);
+        }
+
+        // feed cache
         cache_.push(feature_chunk);
+        ready_read_condition_.notify_one();
     }
-    ready_read_condition_.notify_one();
+    int32 remained_feature_len =
+        joint_len - num_chunk * frame_chunk_stride_ * dim_;
+    remained_feature_.Resize(remained_feature_len);
+    remained_feature_.CopyFromVec(joint_feature.Range(
+        frame_chunk_stride_ * num_chunk * dim_, remained_feature_len));
     return result;
 }

-void Reset() {
-    // std::lock_guard<std::mutex> lock(mutex_);
-    return;
-}
-
 }  // namespace ppspeech
\ No newline at end of file
diff --git a/speechx/speechx/frontend/audio/feature_cache.h b/speechx/speechx/frontend/audio/feature_cache.h
index 99961b5e..607f72c0 100644
--- a/speechx/speechx/frontend/audio/feature_cache.h
+++ b/speechx/speechx/frontend/audio/feature_cache.h
@@ -19,10 +19,18 @@

 namespace ppspeech {

+struct FeatureCacheOptions {
+    int32 max_size;
+    int32 frame_chunk_size;
+    int32 frame_chunk_stride;
+    FeatureCacheOptions()
+        : max_size(kint16max), frame_chunk_size(1), frame_chunk_stride(1) {}
+};
+
 class FeatureCache : public FrontendInterface {
   public:
     explicit FeatureCache(
-        int32 max_size = kint16max,
+        FeatureCacheOptions opts,
         std::unique_ptr<FrontendInterface> base_extractor = NULL);

     // Feed feats or waves
@@ -32,12 +40,15 @@ class FeatureCache : public FrontendInterface {
     virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);

     // feat dim
-    virtual size_t Dim() const { return base_extractor_->Dim(); }
+    virtual size_t Dim() const { return dim_; }

     virtual void SetFinished() {
+        // std::unique_lock<std::mutex> lock(mutex_);
         base_extractor_->SetFinished();
+        LOG(INFO) << "set finished";
         // read the last chunk data
         Compute();
+        // ready_feed_condition_.notify_one();
     }
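+    // Editor's note (hedged worked example, not part of the original patch):
+    // with dim_ = 161, frame_chunk_size_ = 7, frame_chunk_stride_ = 4 and 15
+    // buffered frames, Compute() emits ((15 - 7) / 4) + 1 = 3 chunks covering
+    // frames [0, 7), [4, 11) and [8, 15), then keeps the remaining
+    // 15 - 3 * 4 = 3 frames in remained_feature_ for the next call.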
diff --git a/speechx/speechx/frontend/audio/feature_cache.h b/speechx/speechx/frontend/audio/feature_cache.h
index 99961b5e..607f72c0 100644
--- a/speechx/speechx/frontend/audio/feature_cache.h
+++ b/speechx/speechx/frontend/audio/feature_cache.h
@@ -19,10 +19,18 @@
 namespace ppspeech {
 
+struct FeatureCacheOptions {
+    int32 max_size;
+    int32 frame_chunk_size;
+    int32 frame_chunk_stride;
+    FeatureCacheOptions()
+        : max_size(kint16max), frame_chunk_size(1), frame_chunk_stride(1) {}
+};
+
 class FeatureCache : public FrontendInterface {
   public:
     explicit FeatureCache(
-        int32 max_size = kint16max,
+        FeatureCacheOptions opts,
         std::unique_ptr<FrontendInterface> base_extractor = NULL);
 
     // Feed feats or waves
@@ -32,12 +40,15 @@ class FeatureCache : public FrontendInterface {
     virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
 
     // feat dim
-    virtual size_t Dim() const { return base_extractor_->Dim(); }
+    virtual size_t Dim() const { return dim_; }
 
     virtual void SetFinished() {
+        // std::unique_lock<std::mutex> lock(mutex_);
         base_extractor_->SetFinished();
+        LOG(INFO) << "set finished";
         // read the last chunk data
         Compute();
+        // ready_feed_condition_.notify_one();
     }
 
     virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
 
@@ -52,9 +63,13 @@
   private:
     bool Compute();
 
+    int32 dim_;
     size_t max_size_;
-    std::unique_ptr<FrontendInterface> base_extractor_;
+    int32 frame_chunk_size_;
+    int32 frame_chunk_stride_;
+    kaldi::Vector<kaldi::BaseFloat> remained_feature_;
+    std::unique_ptr<FrontendInterface> base_extractor_;
 
     std::mutex mutex_;
     std::queue<kaldi::Vector<kaldi::BaseFloat>> cache_;
     std::condition_variable ready_feed_condition_;
diff --git a/speechx/speechx/frontend/audio/feature_pipeline.cc b/speechx/speechx/frontend/audio/feature_pipeline.cc
new file mode 100644
index 00000000..86eca2e0
--- /dev/null
+++ b/speechx/speechx/frontend/audio/feature_pipeline.cc
@@ -0,0 +1,36 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "frontend/audio/feature_pipeline.h"
+
+namespace ppspeech {
+
+using std::unique_ptr;
+
+FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
+    unique_ptr<FrontendInterface> data_source(
+        new ppspeech::AudioCache(1000 * kint16max, opts.convert2PCM32));
+
+    unique_ptr<FrontendInterface> linear_spectrogram(
+        new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts,
+                                        std::move(data_source)));
+
+    unique_ptr<FrontendInterface> cmvn(
+        new ppspeech::CMVN(opts.cmvn_file, std::move(linear_spectrogram)));
+
+    base_extractor_.reset(
+        new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn)));
+}
+
+}  // ppspeech
\ No newline at end of file
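feature_pipeline.cc wires the frontend as a chain of owning decorators, AudioCache -> LinearSpectrogram -> CMVN -> FeatureCache, each stage holding its upstream stage in a unique_ptr, so one Read() on the outermost stage pulls raw audio all the way through. The same ownership pattern in miniature (types below are invented stand-ins for FrontendInterface, not the speechx classes):

    #include <memory>
    #include <vector>

    struct Stage {  // stand-in for FrontendInterface
        virtual ~Stage() = default;
        virtual bool Read(std::vector<float>* out) = 0;
    };

    struct Source : Stage {  // stand-in for AudioCache
        bool Read(std::vector<float>* out) override {
            *out = {1.f, 2.f, 3.f};  // pretend buffered audio
            return true;
        }
    };

    struct Transform : Stage {  // stand-in for spectrogram/CMVN stages
        std::unique_ptr<Stage> base;  // owns its upstream stage
        explicit Transform(std::unique_ptr<Stage> b) : base(std::move(b)) {}
        bool Read(std::vector<float>* out) override {
            if (!base->Read(out)) return false;   // pull from upstream
            for (float& v : *out) v *= 0.5f;      // this stage's own work
            return true;
        }
    };

    // Same shape as FeaturePipeline's constructor: build inside-out.
    std::unique_ptr<Stage> BuildPipeline() {
        std::unique_ptr<Stage> src(new Source());
        return std::unique_ptr<Stage>(new Transform(std::move(src)));
    }

    int main() {
        auto pipe = BuildPipeline();
        std::vector<float> out;
        return (pipe->Read(&out) && out[0] == 0.5f) ? 0 : 1;
    }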
diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/speechx/speechx/frontend/audio/feature_pipeline.h
new file mode 100644
index 00000000..7bd6c84f
--- /dev/null
+++ b/speechx/speechx/frontend/audio/feature_pipeline.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// todo refactor later (SGoat)
+
+#pragma once
+
+#include "frontend/audio/audio_cache.h"
+#include "frontend/audio/data_cache.h"
+#include "frontend/audio/feature_cache.h"
+#include "frontend/audio/frontend_itf.h"
+#include "frontend/audio/linear_spectrogram.h"
+#include "frontend/audio/normalizer.h"
+
+namespace ppspeech {
+
+struct FeaturePipelineOptions {
+    std::string cmvn_file;
+    bool convert2PCM32;
+    LinearSpectrogramOptions linear_spectrogram_opts;
+    FeatureCacheOptions feature_cache_opts;
+    FeaturePipelineOptions()
+        : cmvn_file(""),
+          convert2PCM32(false),
+          linear_spectrogram_opts(),
+          feature_cache_opts() {}
+};
+
+class FeaturePipeline : public FrontendInterface {
+  public:
+    explicit FeaturePipeline(const FeaturePipelineOptions& opts);
+    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
+        base_extractor_->Accept(waves);
+    }
+    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
+        return base_extractor_->Read(feats);
+    }
+    virtual size_t Dim() const { return base_extractor_->Dim(); }
+    virtual void SetFinished() { base_extractor_->SetFinished(); }
+    virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
+    virtual void Reset() { base_extractor_->Reset(); }
+
+  private:
+    std::unique_ptr<FrontendInterface> base_extractor_;
+};
+}
\ No newline at end of file
diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.cc b/speechx/speechx/frontend/audio/linear_spectrogram.cc
index d6ae3d01..9ef5e766 100644
--- a/speechx/speechx/frontend/audio/linear_spectrogram.cc
+++ b/speechx/speechx/frontend/audio/linear_spectrogram.cc
@@ -52,16 +52,16 @@ bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
     if (flag == false || input_feats.Dim() == 0) return false;
 
     int32 feat_len = input_feats.Dim();
-    int32 left_len = reminded_wav_.Dim();
+    int32 left_len = remained_wav_.Dim();
     Vector<BaseFloat> waves(feat_len + left_len);
-    waves.Range(0, left_len).CopyFromVec(reminded_wav_);
+    waves.Range(0, left_len).CopyFromVec(remained_wav_);
     waves.Range(left_len, feat_len).CopyFromVec(input_feats);
     Compute(waves, feats);
     int32 frame_shift = opts_.frame_opts.WindowShift();
     int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts);
     int32 left_samples = waves.Dim() - frame_shift * num_frames;
-    reminded_wav_.Resize(left_samples);
-    reminded_wav_.CopyFromVec(
+    remained_wav_.Resize(left_samples);
+    remained_wav_.CopyFromVec(
         waves.Range(frame_shift * num_frames, left_samples));
     return true;
 }
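linear_spectrogram.cc renames reminded_wav_ to remained_wav_ (fixing the typo) while keeping the carryover logic: after Compute(), the samples past the last full frame shift are saved for the next Read(). Checked standalone under a typical 25 ms window / 10 ms shift at 16 kHz (assumed values; the real ones come from FrameExtractionOptions), with the new 0.1 s streaming chunk:

    #include <cassert>

    // Kaldi-style frame count with snip-edges: a frame fits while
    // frame_shift * n + frame_length <= num_samples.
    int NumFrames(int num_samples, int frame_length, int frame_shift) {
        if (num_samples < frame_length) return 0;
        return (num_samples - frame_length) / frame_shift + 1;
    }

    int main() {
        const int frame_length = 400;  // 25 ms at 16 kHz
        const int frame_shift = 160;   // 10 ms at 16 kHz
        const int chunk = 1600;        // 0.1 s streaming chunk

        int n = NumFrames(chunk, frame_length, frame_shift);
        // remained_wav_ receives waves.Dim() - frame_shift * num_frames:
        int remained = chunk - frame_shift * n;
        assert(n == 8 && remained == 320);  // 20 ms tail carried over
        return 0;
    }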
diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.h b/speechx/speechx/frontend/audio/linear_spectrogram.h
index 689ec2c4..2764b7cf 100644
--- a/speechx/speechx/frontend/audio/linear_spectrogram.h
+++ b/speechx/speechx/frontend/audio/linear_spectrogram.h
@@ -25,12 +25,12 @@ struct LinearSpectrogramOptions {
     kaldi::FrameExtractionOptions frame_opts;
     kaldi::BaseFloat streaming_chunk;  // second
 
-    LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {}
+    LinearSpectrogramOptions() : streaming_chunk(0.1), frame_opts() {}
 
     void Register(kaldi::OptionsItf* opts) {
         opts->Register("streaming-chunk",
                        &streaming_chunk,
-                       "streaming chunk size, default: 0.36 sec");
+                       "streaming chunk size, default: 0.1 sec");
         frame_opts.Register(opts);
     }
 };
@@ -48,7 +48,7 @@ class LinearSpectrogram : public FrontendInterface {
     virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
     virtual void Reset() {
         base_extractor_->Reset();
-        reminded_wav_.Resize(0);
+        remained_wav_.Resize(0);
     }
 
   private:
@@ -60,7 +60,7 @@ kaldi::BaseFloat hanning_window_energy_;
     LinearSpectrogramOptions opts_;
     std::unique_ptr<FrontendInterface> base_extractor_;
-    kaldi::Vector<kaldi::BaseFloat> reminded_wav_;
+    kaldi::Vector<kaldi::BaseFloat> remained_wav_;
     int chunk_sample_size_;
     DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
 };
diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc
index 3f5dadd2..465f64a9 100644
--- a/speechx/speechx/nnet/decodable.cc
+++ b/speechx/speechx/nnet/decodable.cc
@@ -78,7 +78,6 @@ bool Decodable::AdvanceChunk() {
     }
     int32 nnet_dim = 0;
     Vector<BaseFloat> inferences;
-    Matrix<BaseFloat> nnet_cache_tmp;
     nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
     nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
     nnet_cache_.CopyRowsFromVec(inferences);
diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh
index 96ab84d6..87c24b09 100755
--- a/tests/unit/cli/test_cli.sh
+++ b/tests/unit/cli/test_cli.sh
@@ -1,5 +1,6 @@
 #!/bin/bash
 set -e
+
 # Audio classification
 wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/dog.wav
 paddlespeech cls --input ./cat.wav --topk 10
@@ -28,26 +29,16 @@ paddlespeech tts --am tacotron2_csmsc --input "你好,欢迎使用百度飞桨
 paddlespeech tts --am tacotron2_csmsc --voc wavernn_csmsc --input "你好,欢迎使用百度飞桨深度学习框架!"
 paddlespeech tts --am tacotron2_ljspeech --voc pwgan_ljspeech --lang en --input "Life was like a box of chocolates, you never know what you're gonna get."
 
-
 # Speech Translation (only support linux)
 paddlespeech st --input ./en.wav
 
-
-# batch process
-echo -e "1 欢迎光临。\n2 谢谢惠顾。" | paddlespeech tts
-
-# shell pipeline
-paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
-
-# stats
-paddlespeech stats --task asr
-paddlespeech stats --task tts
-paddlespeech stats --task cls
-
 # Speaker Verification
 wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
 paddlespeech vector --task spk --input 85236145389.wav
 
+# batch process
+echo -e "1 欢迎光临。\n2 谢谢惠顾。" | paddlespeech tts
+
 echo -e "demo1 85236145389.wav \n demo2 85236145389.wav" > vec.job
 paddlespeech vector --task spk --input vec.job
 
@@ -55,4 +46,13 @@ echo -e "demo3 85236145389.wav \n demo4 85236145389.wav" | paddlespeech vector -
 rm 85236145389.wav
 rm vec.job
 
+# shell pipeline
+paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
+
+# stats
+paddlespeech stats --task asr
+paddlespeech stats --task tts
+paddlespeech stats --task cls
+paddlespeech stats --task text
+paddlespeech stats --task vector
+paddlespeech stats --task st
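decodable.cc drops the unused nnet_cache_tmp matrix; the flat inference vector is still reshaped row-major into nnet_cache_ by Resize followed by CopyRowsFromVec, one frame of posteriors per row. The same convention, sketched with plain containers rather than the kaldi::Matrix API:

    #include <cassert>
    #include <vector>

    // Row-major reshape: (rows * cols) floats -> rows x cols, one frame
    // per row, as Resize + CopyRowsFromVec fills nnet_cache_.
    std::vector<std::vector<float>> Reshape(const std::vector<float>& flat,
                                            int cols) {
        const int rows = static_cast<int>(flat.size()) / cols;
        std::vector<std::vector<float>> m(rows, std::vector<float>(cols));
        for (int r = 0; r < rows; ++r)
            for (int c = 0; c < cols; ++c) m[r][c] = flat[r * cols + c];
        return m;
    }

    int main() {
        std::vector<float> inferences = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
        auto cache = Reshape(inferences, /*nnet_dim=*/3);  // 2 frames x 3
        assert(cache.size() == 2 && cache[1][0] == 3.f);
        return 0;
    }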