Merge branch 'PaddlePaddle:develop' into cluster

pull/1722/head
qingen 2 years ago committed by GitHub
commit 26d5dded7c

@ -29,9 +29,10 @@ from ..download import get_path_from_url
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
@ -39,94 +40,13 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor']
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"conformer_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
'md5':
'76cb19ed857e6623856b7cd7ebbfeda4',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/wenetspeech',
},
"transformer_librispeech-en-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'2c667da24922aad391eacafe37bc1660',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/transformer/checkpoints/avg_10',
},
"deepspeech2offline_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz',
'md5':
'23e16c69730a1cb5d735c98c83c21e16',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2offline_librispeech-en-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'f5666c81ad015c8de03aac2bc92e5762',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
'lm_md5':
'099a601759d467cd0a8523ff939819c5'
},
}
model_alias = {
"deepspeech2offline":
"paddlespeech.s2t.models.ds2:DeepSpeech2Model",
"deepspeech2online":
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer":
"paddlespeech.s2t.models.u2:U2Model",
"transformer":
"paddlespeech.s2t.models.u2:U2Model",
"wenetspeech":
"paddlespeech.s2t.models.u2:U2Model",
}
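# Illustrative sketch (not part of this diff): how the tag convention described in the
# comments above is assumed to be used, i.e. how a CLI call such as
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000" maps onto a
# pretrained_models key and a model_alias import path. resolve_model is a hypothetical helper.
def resolve_model(model: str='conformer_wenetspeech',
                  lang: str='zh',
                  sample_rate: int=16000):
    sample_rate_str = '16k' if sample_rate == 16000 else '8k'
    tag = model + '-' + lang + '-' + sample_rate_str  # e.g. "conformer_wenetspeech-zh-16k"
    resource = pretrained_models[tag]  # url/md5/cfg_path/ckpt_path entry for this tag
    model_name = model[:model.rindex('_')]  # "{model_name}_{dataset}" -> "conformer"
    import_path = model_alias[model_name]  # e.g. "paddlespeech.s2t.models.u2:U2Model"
    return tag, resource, import_path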
@cli_register(
name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor):
def __init__(self):
super(ASRExecutor, self).__init__()
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.parser = argparse.ArgumentParser(
prog='paddlespeech.asr', add_help=True)
@ -136,7 +56,9 @@ class ASRExecutor(BaseExecutor):
'--model',
type=str,
default='conformer_wenetspeech',
choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()],
choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys()
],
help='Choose model type of asr task.')
self.parser.add_argument(
'--lang',
@ -192,23 +114,6 @@ class ASRExecutor(BaseExecutor):
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(self,
model_type: str='wenetspeech',
lang: str='zh',
@ -219,6 +124,7 @@ class ASRExecutor(BaseExecutor):
"""
Init model and other resources from a specific path.
"""
logger.info("start to init the model")
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
@ -228,19 +134,21 @@ class ASRExecutor(BaseExecutor):
tag = model_type + '-' + lang + '-' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.cfg_path = os.path.join(
res_path, self.pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(
res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams")
res_path,
self.pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
@ -255,8 +163,8 @@ class ASRExecutor(BaseExecutor):
self.collate_fn_test = SpeechCollator.from_config(self.config)
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab)
lm_url = pretrained_models[tag]['lm_url']
lm_md5 = pretrained_models[tag]['lm_md5']
lm_url = self.pretrained_models[tag]['lm_url']
lm_md5 = self.pretrained_models[tag]['lm_md5']
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
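# Illustrative sketch (not part of this diff): download_lm (called above) is assumed to
# fetch the KenLM language model into the directory of config.decode.lang_model_path and
# verify it against the given md5, e.g. via the get_path_from_url helper imported earlier.
def download_lm(self, url, lm_dir, md5sum):
    # The .klm file is not an archive, so decompression is skipped.
    get_path_from_url(url=url, root_dir=lm_dir, md5sum=md5sum, decompress=False)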
@ -269,12 +177,11 @@ class ASRExecutor(BaseExecutor):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.decoding_method = decode_method
else:
raise Exception("wrong type")
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, model_alias)
model_class = dynamic_import(model_name, self.model_alias)
model_conf = self.config
model = model_class.from_config(model_conf)
self.model = model
@ -347,12 +254,14 @@ class ASRExecutor(BaseExecutor):
else:
raise Exception("wrong type")
logger.info("audio feat process success")
@paddle.no_grad()
def infer(self, model_type: str):
"""
Model inference and result stored in self.output.
"""
logger.info("start to infer the model to get the output")
cfg = self.config.decode
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
@ -369,17 +278,22 @@ class ASRExecutor(BaseExecutor):
self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type:
result_transcripts = self.model.decode(
audio,
audio_len,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
self._outputs["result"] = result_transcripts[0][0]
logger.info(f"we will use the transformer like model : {model_type}")
try:
result_transcripts = self.model.decode(
audio,
audio_len,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
self._outputs["result"] = result_transcripts[0][0]
except Exception as e:
logger.exception(e)
else:
raise Exception("invalid model name")

@ -0,0 +1,97 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"conformer_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
'md5':
'76cb19ed857e6623856b7cd7ebbfeda4',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/wenetspeech',
},
"transformer_librispeech-en-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'2c667da24922aad391eacafe37bc1660',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/transformer/checkpoints/avg_10',
},
"deepspeech2offline_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz',
'md5':
'23e16c69730a1cb5d735c98c83c21e16',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2offline_librispeech-en-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'f5666c81ad015c8de03aac2bc92e5762',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
'lm_md5':
'099a601759d467cd0a8523ff939819c5'
},
}
model_alias = {
"deepspeech2offline":
"paddlespeech.s2t.models.ds2:DeepSpeech2Model",
"deepspeech2online":
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer":
"paddlespeech.s2t.models.u2:U2Model",
"conformer_online":
"paddlespeech.s2t.models.u2:U2Model",
"transformer":
"paddlespeech.s2t.models.u2:U2Model",
"wenetspeech":
"paddlespeech.s2t.models.u2:U2Model",
}

@ -25,55 +25,23 @@ import yaml
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__ = ['CLSExecutor']
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"panns_cnn6-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
'md5': '4cf09194a95df024fd12f84712cf0f9c',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn6.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn10-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
'md5': 'cb8427b22176cc2116367d14847f5413',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn10.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn14-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
'md5': 'e3b9b5614a1595001161d0ab95edee97',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn14.pdparams',
'label_file': 'audioset_labels.txt',
},
}
model_alias = {
"panns_cnn6": "paddlespeech.cls.models.panns:CNN6",
"panns_cnn10": "paddlespeech.cls.models.panns:CNN10",
"panns_cnn14": "paddlespeech.cls.models.panns:CNN14",
}
@cli_register(
name='paddlespeech.cls', description='Audio classification infer command.')
class CLSExecutor(BaseExecutor):
def __init__(self):
super(CLSExecutor, self).__init__()
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.parser = argparse.ArgumentParser(
prog='paddlespeech.cls', add_help=True)
@ -83,7 +51,9 @@ class CLSExecutor(BaseExecutor):
'--model',
type=str,
default='panns_cnn14',
choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()],
choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys()
],
help='Choose model type of cls task.')
self.parser.add_argument(
'--config',
@ -121,23 +91,6 @@ class CLSExecutor(BaseExecutor):
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(self,
model_type: str='panns_cnn14',
cfg_path: Optional[os.PathLike]=None,
@ -153,12 +106,12 @@ class CLSExecutor(BaseExecutor):
if label_file is None or ckpt_path is None:
tag = model_type + '-' + '32k' # panns_cnn14-32k
self.res_path = self._get_pretrained_path(tag)
self.cfg_path = os.path.join(self.res_path,
pretrained_models[tag]['cfg_path'])
self.label_file = os.path.join(self.res_path,
pretrained_models[tag]['label_file'])
self.ckpt_path = os.path.join(self.res_path,
pretrained_models[tag]['ckpt_path'])
self.cfg_path = os.path.join(
self.res_path, self.pretrained_models[tag]['cfg_path'])
self.label_file = os.path.join(
self.res_path, self.pretrained_models[tag]['label_file'])
self.ckpt_path = os.path.join(
self.res_path, self.pretrained_models[tag]['ckpt_path'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.label_file = os.path.abspath(label_file)
@ -175,7 +128,7 @@ class CLSExecutor(BaseExecutor):
self._label_list.append(line.strip())
# model
model_class = dynamic_import(model_type, model_alias)
model_class = dynamic_import(model_type, self.model_alias)
model_dict = paddle.load(self.ckpt_path)
self.model = model_class(extract_embedding=False)
self.model.set_state_dict(model_dict)
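# Illustrative sketch (not part of this diff): CLSExecutor used from Python, assuming
# the call signature documented for the CLI; "./cat.wav" is a placeholder path.
from paddlespeech.cli.cls.infer import CLSExecutor

cls_executor = CLSExecutor()
result = cls_executor(audio_file="./cat.wav", model="panns_cnn14", topk=1)
print(result)  # expected: top-1 AudioSet label with its score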

@ -0,0 +1,47 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"panns_cnn6-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
'md5': '4cf09194a95df024fd12f84712cf0f9c',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn6.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn10-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
'md5': 'cb8427b22176cc2116367d14847f5413',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn10.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn14-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
'md5': 'e3b9b5614a1595001161d0ab95edee97',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn14.pdparams',
'label_file': 'audioset_labels.txt',
},
}
model_alias = {
"panns_cnn6": "paddlespeech.cls.models.panns:CNN6",
"panns_cnn10": "paddlespeech.cls.models.panns:CNN10",
"panns_cnn14": "paddlespeech.cls.models.panns:CNN14",
}

@ -25,6 +25,8 @@ from typing import Union
import paddle
from .log import logger
from .utils import download_and_decompress
from .utils import MODEL_HOME
class BaseExecutor(ABC):
@ -35,19 +37,8 @@ class BaseExecutor(ABC):
def __init__(self):
self._inputs = OrderedDict()
self._outputs = OrderedDict()
@abstractmethod
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
Args:
tag (str): A tag of pretrained model.
Returns:
os.PathLike: The path on which resources of pretrained model locate.
"""
pass
self.pretrained_models = OrderedDict()
self.model_alias = OrderedDict()
@abstractmethod
def _init_from_path(self, *args, **kwargs):
@ -227,3 +218,20 @@ class BaseExecutor(ABC):
]
for l in loggers:
l.disabled = True
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(self.pretrained_models.keys())
assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(self.pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
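# Illustrative sketch (not part of this diff): with pretrained_models, model_alias and
# _get_pretrained_path now living on BaseExecutor, a task executor only has to register
# its own resource dicts in __init__ and can reuse the shared download logic.
# DummyExecutor, its tag and its URL below are hypothetical placeholders.
class DummyExecutor(BaseExecutor):
    # Only the pieces relevant to _get_pretrained_path are shown; the other abstract
    # methods of BaseExecutor are omitted from this sketch.
    def __init__(self):
        super().__init__()
        self.pretrained_models = {
            "dummy_csmsc-zh-16k": {
                'url': 'https://example.com/dummy_csmsc.tar.gz',  # placeholder URL
                'md5': '00000000000000000000000000000000',  # placeholder md5
                'cfg_path': 'model.yaml',
                'ckpt_path': 'exp/dummy/checkpoints/avg_1',
            }
        }
        self.model_alias = {"dummy": "paddlespeech.s2t.models.u2:U2Model"}

    def _init_from_path(self, tag: str="dummy_csmsc-zh-16k"):
        # Inherited from BaseExecutor: downloads and caches the archive under
        # MODEL_HOME/<tag> and returns the absolute decompressed path.
        res_path = self._get_pretrained_path(tag)
        return res_path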

@ -32,40 +32,24 @@ from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from .pretrained_models import kaldi_bins
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ["STExecutor"]
pretrained_models = {
"fat_st_ted-en-zh": {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
"md5":
"d62063f35a16d91210a71081bd2dd557",
"cfg_path":
"model.yaml",
"ckpt_path":
"exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
}
}
model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"}
kaldi_bins = {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
"md5":
"c0682303b3f3393dbf6ed4c4e35a53eb",
}
@cli_register(
name="paddlespeech.st", description="Speech translation infer command.")
class STExecutor(BaseExecutor):
def __init__(self):
super(STExecutor, self).__init__()
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.kaldi_bins = kaldi_bins
self.parser = argparse.ArgumentParser(
prog="paddlespeech.st", add_help=True)
@ -75,7 +59,9 @@ class STExecutor(BaseExecutor):
"--model",
type=str,
default="fat_st_ted",
choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()],
choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys()
],
help="Choose model type of st task.")
self.parser.add_argument(
"--src_lang",
@ -119,28 +105,11 @@ class STExecutor(BaseExecutor):
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
"Use pretrained model stored in: {}".format(decompressed_path))
return decompressed_path
def _set_kaldi_bins(self) -> os.PathLike:
"""
Download and returns kaldi_bins resources path of current task.
"""
decompressed_path = download_and_decompress(kaldi_bins, MODEL_HOME)
decompressed_path = download_and_decompress(self.kaldi_bins, MODEL_HOME)
decompressed_path = os.path.abspath(decompressed_path)
logger.info("Kaldi_bins stored in: {}".format(decompressed_path))
if "LD_LIBRARY_PATH" in os.environ:
@ -197,7 +166,7 @@ class STExecutor(BaseExecutor):
model_conf = self.config
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, model_alias)
model_class = dynamic_import(model_name, self.model_alias)
self.model = model_class.from_config(model_conf)
self.model.eval()
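# Illustrative sketch (not part of this diff): STExecutor used from Python, assuming
# the call signature documented for the CLI; "./en.wav" is a placeholder path.
from paddlespeech.cli.st.infer import STExecutor

st = STExecutor()
translation = st(
    audio_file="./en.wav",
    model="fat_st_ted",
    src_lang="en",
    tgt_lang="zh")
print(translation)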

@ -0,0 +1,35 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
"fat_st_ted-en-zh": {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
"md5":
"d62063f35a16d91210a71081bd2dd557",
"cfg_path":
"model.yaml",
"ckpt_path":
"exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
}
}
model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"}
kaldi_bins = {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
"md5":
"c0682303b3f3393dbf6ed4c4e35a53eb",
}

@ -16,7 +16,6 @@ from typing import List
from prettytable import PrettyTable
from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper
@ -27,7 +26,8 @@ model_name_format = {
'cls': 'Model-Sample Rate',
'st': 'Model-Source language-Target language',
'text': 'Model-Task-Language',
'tts': 'Model-Language'
'tts': 'Model-Language',
'vector': 'Model-Sample Rate'
}
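# Illustrative sketch (not part of this diff): show_support_models (defined just below)
# is assumed to split each pretrained model tag on '-' and lay the pieces out under the
# column names in model_name_format, e.g. for task 'vector' the tag
# "ecapatdnn_voxceleb12-16k" becomes the row ("ecapatdnn_voxceleb12", "16k") under the
# columns "Model" and "Sample Rate". render_support_models is a hypothetical stand-in.
from prettytable import PrettyTable

def render_support_models(task: str, pretrained_models: dict):
    fields = model_name_format[task].split("-")
    table = PrettyTable(fields)
    for tag in pretrained_models:
        table.add_row(tag.split("-"))
    print(table)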
@ -36,18 +36,18 @@ model_name_format = {
description='Get speech tasks support models list.')
class StatsExecutor():
def __init__(self):
super(StatsExecutor, self).__init__()
super().__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech.stats', add_help=True)
self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector']
self.parser.add_argument(
'--task',
type=str,
default='asr',
choices=['asr', 'cls', 'st', 'text', 'tts'],
choices=self.task_choices,
help='Choose speech task.',
required=True)
self.task_choices = ['asr', 'cls', 'st', 'text', 'tts']
def show_support_models(self, pretrained_models: dict):
fields = model_name_format[self.task].split("-")
@ -61,73 +61,15 @@ class StatsExecutor():
Command line entry.
"""
parser_args = self.parser.parse_args(argv)
self.task = parser_args.task
if self.task not in self.task_choices:
logger.error(
"Please input correct speech task, choices = ['asr', 'cls', 'st', 'text', 'tts']"
)
has_exceptions = False
try:
self(parser_args.task)
except Exception as e:
has_exceptions = True
if has_exceptions:
return False
elif self.task == 'asr':
try:
from ..asr.infer import pretrained_models
logger.info(
"Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error("Failed to get the list of ASR pretrained models.")
return False
elif self.task == 'cls':
try:
from ..cls.infer import pretrained_models
logger.info(
"Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error("Failed to get the list of CLS pretrained models.")
return False
elif self.task == 'st':
try:
from ..st.infer import pretrained_models
logger.info(
"Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error("Failed to get the list of ST pretrained models.")
return False
elif self.task == 'text':
try:
from ..text.infer import pretrained_models
logger.info(
"Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error(
"Failed to get the list of TEXT pretrained models.")
return False
elif self.task == 'tts':
try:
from ..tts.infer import pretrained_models
logger.info(
"Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
return True
except BaseException:
logger.error("Failed to get the list of TTS pretrained models.")
return False
else:
return True
@stats_wrapper
def __call__(
@ -138,13 +80,12 @@ class StatsExecutor():
"""
self.task = task
if self.task not in self.task_choices:
print(
"Please input correct speech task, choices = ['asr', 'cls', 'st', 'text', 'tts']"
)
print("Please input correct speech task, choices = " + str(
self.task_choices))
elif self.task == 'asr':
try:
from ..asr.infer import pretrained_models
from ..asr.pretrained_models import pretrained_models
print(
"Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API"
)
@ -154,7 +95,7 @@ class StatsExecutor():
elif self.task == 'cls':
try:
from ..cls.infer import pretrained_models
from ..cls.pretrained_models import pretrained_models
print(
"Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
@ -164,7 +105,7 @@ class StatsExecutor():
elif self.task == 'st':
try:
from ..st.infer import pretrained_models
from ..st.pretrained_models import pretrained_models
print(
"Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API"
)
@ -174,7 +115,7 @@ class StatsExecutor():
elif self.task == 'text':
try:
from ..text.infer import pretrained_models
from ..text.pretrained_models import pretrained_models
print(
"Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API"
)
@ -184,10 +125,22 @@ class StatsExecutor():
elif self.task == 'tts':
try:
from ..tts.infer import pretrained_models
from ..tts.pretrained_models import pretrained_models
print(
"Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of TTS pretrained models.")
elif self.task == 'vector':
try:
from ..vector.pretrained_models import pretrained_models
print(
"Here is the list of Speaker Recognition pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print(
"Failed to get the list of Speaker Recognition pretrained models."
)
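# Illustrative sketch (not part of this diff): listing pretrained models from Python,
# mirroring "paddlespeech stats --task vector" on the command line; the
# paddlespeech.cli.stats.infer module path is assumed from the layout of the other
# executors in this PR.
from paddlespeech.cli.stats.infer import StatsExecutor

stats = StatsExecutor()
stats(task='vector')  # prints the Speaker Recognition pretrained model table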

@ -25,58 +25,21 @@ from ...s2t.utils.dynamic_import import dynamic_import
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from .pretrained_models import tokenizer_alias
__all__ = ['TextExecutor']
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"ernie_linear_p7_wudao-punc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz',
'md5':
'12283e2ddde1797c5d1e57036b512746',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
"ernie_linear_p3_wudao-punc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz',
'md5':
'448eb2fdf85b6a997e7e652e80c51dd2',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
}
model_alias = {
"ernie_linear_p7": "paddlespeech.text.models:ErnieLinear",
"ernie_linear_p3": "paddlespeech.text.models:ErnieLinear",
}
tokenizer_alias = {
"ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer",
"ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer",
}
@cli_register(name='paddlespeech.text', description='Text infer command.')
class TextExecutor(BaseExecutor):
def __init__(self):
super(TextExecutor, self).__init__()
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.tokenizer_alias = tokenizer_alias
self.parser = argparse.ArgumentParser(
prog='paddlespeech.text', add_help=True)
@ -92,7 +55,9 @@ class TextExecutor(BaseExecutor):
'--model',
type=str,
default='ernie_linear_p7_wudao',
choices=[tag[:tag.index('-')] for tag in pretrained_models.keys()],
choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys()
],
help='Choose model type of text task.')
self.parser.add_argument(
'--lang',
@ -131,23 +96,6 @@ class TextExecutor(BaseExecutor):
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(self,
task: str='punc',
model_type: str='ernie_linear_p7_wudao',
@ -167,12 +115,12 @@ class TextExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None or vocab_file is None:
tag = '-'.join([model_type, task, lang])
self.res_path = self._get_pretrained_path(tag)
self.cfg_path = os.path.join(self.res_path,
pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(self.res_path,
pretrained_models[tag]['ckpt_path'])
self.vocab_file = os.path.join(self.res_path,
pretrained_models[tag]['vocab_file'])
self.cfg_path = os.path.join(
self.res_path, self.pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(
self.res_path, self.pretrained_models[tag]['ckpt_path'])
self.vocab_file = os.path.join(
self.res_path, self.pretrained_models[tag]['vocab_file'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
@ -187,8 +135,8 @@ class TextExecutor(BaseExecutor):
self._punc_list.append(line.strip())
# model
model_class = dynamic_import(model_name, model_alias)
tokenizer_class = dynamic_import(model_name, tokenizer_alias)
model_class = dynamic_import(model_name, self.model_alias)
tokenizer_class = dynamic_import(model_name, self.tokenizer_alias)
self.model = model_class(
cfg_path=self.cfg_path, ckpt_path=self.ckpt_path)
self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
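# Illustrative sketch (not part of this diff): TextExecutor used from Python for
# punctuation restoration, assuming the call signature documented for the CLI.
from paddlespeech.cli.text.infer import TextExecutor

text_punc = TextExecutor()
result = text_punc(
    text="今天的天气真不错啊你下午有空吗我想约你一起去吃饭",
    task="punc",
    model="ernie_linear_p7_wudao",
    lang="zh")
print(result)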

@ -0,0 +1,54 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"ernie_linear_p7_wudao-punc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz',
'md5':
'12283e2ddde1797c5d1e57036b512746',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
"ernie_linear_p3_wudao-punc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz',
'md5':
'448eb2fdf85b6a997e7e652e80c51dd2',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
}
model_alias = {
"ernie_linear_p7": "paddlespeech.text.models:ErnieLinear",
"ernie_linear_p3": "paddlespeech.text.models:ErnieLinear",
}
tokenizer_alias = {
"ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer",
"ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer",
}

@ -29,9 +29,9 @@ from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
@ -39,299 +39,14 @@ from paddlespeech.t2s.modules.normalizer import ZScore
__all__ = ['TTSExecutor']
pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip',
'md5':
'6f6fa967b408454b6662c8c00c0027cb',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'feats_stats.npy',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip',
'md5':
'ffed800c93deaf16ca9b3af89bfcd747',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_100000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip',
'md5':
'f4dd4a5f49a4552b77981f544ab3392e',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_96400.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
"fastspeech2_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip',
'md5':
'743e5024ca1e17a88c5c271db9779ba4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_66200.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
# tacotron2
"tacotron2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip',
'md5':
'0df4b6f0bcbe0d73c5ed6df8867ab91a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"tacotron2_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip',
'md5':
'6a5eddd81ae0e81d16959b97481135f3',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_60300.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
# pwgan
"pwgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip',
'md5':
'2e481633325b5bdf0a3823c714d2c117',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
"pwgan_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip',
'md5':
'53610ba9708fd3008ccaf8e99dacbaf0',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
"pwgan_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip',
'md5':
'd7598fa41ad362d62f85ffc0f07e3d84',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
"pwgan_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip',
'md5':
'b3da1defcde3e578be71eb284cb89f2c',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
# style_melgan
"style_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'5de2d5348f396de0c966926b8c462755',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip',
'md5':
'70e9131695decbca06a65fe51ed38a72',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip',
'md5':
'3bb49bc75032ed12f79c00c8cc79a09a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip',
'md5':
'7da8f88359bca2457e705d924cf27bd4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# wavernn
"wavernn_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip',
'md5':
'ee37b752f09bcba8f2af3b777ca38e13',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_400000.pdz',
'speech_stats':
'feats_stats.npy',
}
}
model_alias = {
# acoustic model
"speedyspeech":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
"speedyspeech_inference":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
"fastspeech2":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2",
"fastspeech2_inference":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
"tacotron2":
"paddlespeech.t2s.models.tacotron2:Tacotron2",
"tacotron2_inference":
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
# voc
"pwgan":
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
"pwgan_inference":
"paddlespeech.t2s.models.parallel_wavegan:PWGInference",
"mb_melgan":
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
"style_melgan":
"paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
"style_melgan_inference":
"paddlespeech.t2s.models.melgan:StyleMelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
"wavernn":
"paddlespeech.t2s.models.wavernn:WaveRNN",
"wavernn_inference":
"paddlespeech.t2s.models.wavernn:WaveRNNInference",
}
@cli_register(
name='paddlespeech.tts', description='Text to Speech infer command.')
class TTSExecutor(BaseExecutor):
def __init__(self):
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True)
@ -449,22 +164,6 @@ class TTSExecutor(BaseExecutor):
action='store_true',
help='Increase logger verbosity of current task.')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(
self,
am: str='fastspeech2_csmsc',
@ -490,16 +189,15 @@ class TTSExecutor(BaseExecutor):
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_config = os.path.join(am_res_path,
pretrained_models[am_tag]['config'])
self.am_config = os.path.join(
am_res_path, self.pretrained_models[am_tag]['config'])
self.am_ckpt = os.path.join(am_res_path,
pretrained_models[am_tag]['ckpt'])
self.pretrained_models[am_tag]['ckpt'])
self.am_stat = os.path.join(
am_res_path, pretrained_models[am_tag]['speech_stats'])
am_res_path, self.pretrained_models[am_tag]['speech_stats'])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, pretrained_models[am_tag]['phones_dict'])
print("self.phones_dict:", self.phones_dict)
am_res_path, self.pretrained_models[am_tag]['phones_dict'])
logger.info(am_res_path)
logger.info(self.am_config)
logger.info(self.am_ckpt)
@ -509,21 +207,20 @@ class TTSExecutor(BaseExecutor):
self.am_stat = os.path.abspath(am_stat)
self.phones_dict = os.path.abspath(phones_dict)
self.am_res_path = os.path.dirname(os.path.abspath(self.am_config))
print("self.phones_dict:", self.phones_dict)
# for speedyspeech
self.tones_dict = None
if 'tones_dict' in pretrained_models[am_tag]:
if 'tones_dict' in self.pretrained_models[am_tag]:
self.tones_dict = os.path.join(
am_res_path, pretrained_models[am_tag]['tones_dict'])
am_res_path, self.pretrained_models[am_tag]['tones_dict'])
if tones_dict:
self.tones_dict = tones_dict
# for multi speaker fastspeech2
self.speaker_dict = None
if 'speaker_dict' in pretrained_models[am_tag]:
if 'speaker_dict' in self.pretrained_models[am_tag]:
self.speaker_dict = os.path.join(
am_res_path, pretrained_models[am_tag]['speaker_dict'])
am_res_path, self.pretrained_models[am_tag]['speaker_dict'])
if speaker_dict:
self.speaker_dict = speaker_dict
@ -532,12 +229,12 @@ class TTSExecutor(BaseExecutor):
if voc_ckpt is None or voc_config is None or voc_stat is None:
voc_res_path = self._get_pretrained_path(voc_tag)
self.voc_res_path = voc_res_path
self.voc_config = os.path.join(voc_res_path,
pretrained_models[voc_tag]['config'])
self.voc_ckpt = os.path.join(voc_res_path,
pretrained_models[voc_tag]['ckpt'])
self.voc_config = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['config'])
self.voc_ckpt = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['ckpt'])
self.voc_stat = os.path.join(
voc_res_path, pretrained_models[voc_tag]['speech_stats'])
voc_res_path, self.pretrained_models[voc_tag]['speech_stats'])
logger.info(voc_res_path)
logger.info(self.voc_config)
logger.info(self.voc_ckpt)
@ -588,8 +285,9 @@ class TTSExecutor(BaseExecutor):
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_class = dynamic_import(am_name, model_alias)
am_inference_class = dynamic_import(am_name + '_inference', model_alias)
am_class = dynamic_import(am_name, self.model_alias)
am_inference_class = dynamic_import(am_name + '_inference',
self.model_alias)
if am_name == 'fastspeech2':
am = am_class(
@ -618,9 +316,9 @@ class TTSExecutor(BaseExecutor):
# vocoder
# model: {model_name}_{dataset}
voc_name = voc[:voc.rindex('_')]
voc_class = dynamic_import(voc_name, model_alias)
voc_class = dynamic_import(voc_name, self.model_alias)
voc_inference_class = dynamic_import(voc_name + '_inference',
model_alias)
self.model_alias)
if voc_name != 'wavernn':
voc = voc_class(**self.voc_config["generator_params"])
voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"])
@ -735,7 +433,6 @@ class TTSExecutor(BaseExecutor):
am_ckpt = args.am_ckpt
am_stat = args.am_stat
phones_dict = args.phones_dict
print("phones_dict:", phones_dict)
tones_dict = args.tones_dict
speaker_dict = args.speaker_dict
voc = args.voc
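# Illustrative sketch (not part of this diff): TTSExecutor used from Python, assuming
# the call signature documented for the CLI; "output.wav" is a placeholder path.
from paddlespeech.cli.tts.infer import TTSExecutor

tts = TTSExecutor()
tts(
    text="今天天气十分不错。",
    am="fastspeech2_csmsc",
    voc="pwgan_csmsc",
    lang="zh",
    output="output.wav")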

@ -0,0 +1,300 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip',
'md5':
'6f6fa967b408454b6662c8c00c0027cb',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'feats_stats.npy',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip',
'md5':
'ffed800c93deaf16ca9b3af89bfcd747',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_100000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip',
'md5':
'f4dd4a5f49a4552b77981f544ab3392e',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_96400.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
"fastspeech2_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip',
'md5':
'743e5024ca1e17a88c5c271db9779ba4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_66200.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
# tacotron2
"tacotron2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip',
'md5':
'0df4b6f0bcbe0d73c5ed6df8867ab91a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"tacotron2_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip',
'md5':
'6a5eddd81ae0e81d16959b97481135f3',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_60300.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
# pwgan
"pwgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip',
'md5':
'2e481633325b5bdf0a3823c714d2c117',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
"pwgan_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip',
'md5':
'53610ba9708fd3008ccaf8e99dacbaf0',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
"pwgan_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip',
'md5':
'd7598fa41ad362d62f85ffc0f07e3d84',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
"pwgan_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip',
'md5':
'b3da1defcde3e578be71eb284cb89f2c',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
# style_melgan
"style_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'5de2d5348f396de0c966926b8c462755',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip',
'md5':
'70e9131695decbca06a65fe51ed38a72',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip',
'md5':
'3bb49bc75032ed12f79c00c8cc79a09a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip',
'md5':
'7da8f88359bca2457e705d924cf27bd4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# wavernn
"wavernn_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip',
'md5':
'ee37b752f09bcba8f2af3b777ca38e13',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_400000.pdz',
'speech_stats':
'feats_stats.npy',
}
}
model_alias = {
# acoustic model
"speedyspeech":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
"speedyspeech_inference":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
"fastspeech2":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2",
"fastspeech2_inference":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
"tacotron2":
"paddlespeech.t2s.models.tacotron2:Tacotron2",
"tacotron2_inference":
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
# voc
"pwgan":
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
"pwgan_inference":
"paddlespeech.t2s.models.parallel_wavegan:PWGInference",
"mb_melgan":
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
"style_melgan":
"paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
"style_melgan_inference":
"paddlespeech.t2s.models.melgan:StyleMelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
"wavernn":
"paddlespeech.t2s.models.wavernn:WaveRNN",
"wavernn_inference":
"paddlespeech.t2s.models.wavernn:WaveRNNInference",
}

@ -27,45 +27,24 @@ from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{sr}][-...]".
# e.g. "ecapatdnn_voxceleb12-16k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav"
"ecapatdnn_voxceleb12-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz',
'md5':
'cc33023c54ab346cd318408f43fcaf95',
'cfg_path':
'conf/model.yaml', # the yaml config path
'ckpt_path':
'model/model', # the format is ${dir}/{model_name},
# so the first 'model' is dir, the second 'model' is the name
# this means we have a model stored as model/model.pdparams
},
}
model_alias = {
"ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn",
}
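# Illustrative sketch (not part of this diff): VectorExecutor used from Python for
# speaker embedding extraction, assuming the call signature documented for the CLI;
# "./audio.wav" is a placeholder path.
from paddlespeech.cli.vector.infer import VectorExecutor

vector_executor = VectorExecutor()
embedding = vector_executor(
    audio_file="./audio.wav",
    model="ecapatdnn_voxceleb12",
    sample_rate=16000)
print(embedding.shape)  # expected: a 1-D ECAPA-TDNN speaker embedding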
@cli_register(
name="paddlespeech.vector",
description="Speech to vector embedding infer command.")
class VectorExecutor(BaseExecutor):
def __init__(self):
super(VectorExecutor, self).__init__()
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.parser = argparse.ArgumentParser(
prog="paddlespeech.vector", add_help=True)
@ -128,8 +107,8 @@ class VectorExecutor(BaseExecutor):
Returns:
bool:
False: some audio raised an error
True: all audio processed successfully
"""
# stage 0: parse the args and get the required args
parser_args = self.parser.parse_args(argv)
@ -289,32 +268,6 @@ class VectorExecutor(BaseExecutor):
return res
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""get the neural network path from the pretrained model list
we store all the pretrained models in the variable `pretrained_models`
Args:
tag (str): model tag in the pretrained model list
Returns:
os.PathLike: the downloaded pretrained model path in the disk
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, \
'The model "{}" you want to use has not been supported,'\
'please choose other models.\n' \
'The support models includes\n\t\t{}'.format(tag, "\n\t\t".join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(self,
model_type: str='ecapatdnn_voxceleb12',
sample_rate: int=16000,
@ -350,10 +303,11 @@ class VectorExecutor(BaseExecutor):
res_path = self._get_pretrained_path(tag)
self.res_path = res_path
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path'])
self.cfg_path = os.path.join(
res_path, self.pretrained_models[tag]['cfg_path'])
self.ckpt_path = os.path.join(
res_path, pretrained_models[tag]['ckpt_path'] + '.pdparams')
res_path,
self.pretrained_models[tag]['ckpt_path'] + '.pdparams')
else:
# get the model from disk
self.cfg_path = os.path.abspath(cfg_path)
@ -373,7 +327,7 @@ class VectorExecutor(BaseExecutor):
logger.info("start to dynamic import the model class")
model_name = model_type[:model_type.rindex('_')]
logger.info(f"model name {model_name}")
model_class = dynamic_import(model_name, model_alias)
model_class = dynamic_import(model_name, self.model_alias)
model_conf = self.config.model
backbone = model_class(**model_conf)
model = SpeakerIdetification(

@ -0,0 +1,36 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[-{dataset}][-{sr}][-...]".
# e.g. "ecapatdnn_voxceleb12-16k".
# Command line and python api use "{model_name}[-{dataset}]" as --model, usage:
# "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav"
"ecapatdnn_voxceleb12-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz',
'md5':
'cc33023c54ab346cd318408f43fcaf95',
'cfg_path':
'conf/model.yaml', # the yaml config path
'ckpt_path':
'model/model', # the format is ${dir}/${model_name}:
# the first 'model' is the directory and the second is the file name,
# i.e. the weights are stored as model/model.pdparams
},
}
model_alias = {
"ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn",
}

@ -279,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
# logp: (B*N, vocab)
logp, cache = self.decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache)
# 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
@ -708,11 +707,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
batch_size = feats.shape[0]
if decoding_method in ['ctc_prefix_beam_search',
'attention_rescoring'] and batch_size > 1:
logger.fatal(
logger.error(
f'decoding mode {decoding_method} must be running with batch_size == 1'
)
logger.error(f"current batch_size is {batch_size}")
sys.exit(1)
if decoding_method == 'attention':
hyps = self.recognize(
feats,

@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase):
# init once
if self._ext_scorer is not None:
return
if language_model_path != '':
logger.info("begin to initialize the external scorer "
"for decoding")

@ -35,3 +35,16 @@
```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
```
## Online ASR Server
### Launch online ASR server
```bash
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### Access online ASR server
```bash
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
```
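The CLI client above wraps a small websocket exchange. As a hedged sketch of scripting that exchange directly, the snippet below talks to the `/ws/asr` route using the `signal`/`asr_results` fields that appear in the server code in this change; the chunk size and the minimal start-message fields are assumptions borrowed from the example client, not an official API.
```python
# Sketch only: drive the streaming ASR websocket route by hand.
# The '/ws/asr' path and the 'signal'/'asr_results' fields appear in the
# server code in this change; everything else (chunking, extra fields)
# is an assumption for illustration.
import asyncio
import json

import soundfile
import websockets


async def recognize(wav_path, url="ws://127.0.0.1:8090/ws/asr"):
    samples, _ = soundfile.read(wav_path, dtype='int16')
    chunk_size = 85 * 16  # 85 ms at 16 kHz, mirroring the example client
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"name": wav_path, "signal": "start"}))
        print(await ws.recv())  # expect a "server_ready" message
        for start in range(0, len(samples), chunk_size):
            await ws.send(samples[start:start + chunk_size].tobytes())
            print(await ws.recv())  # partial {"asr_results": ...}
        await ws.send(json.dumps({"name": wav_path, "signal": "end"}))
        return json.loads(await ws.recv()).get("asr_results")


if __name__ == "__main__":
    print(asyncio.run(recognize("input_16k.wav")))
```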

@ -35,3 +35,17 @@
```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
```
## Streaming ASR
### Launch the streaming speech recognition server
```bash
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### Access the streaming speech recognition server
```bash
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
```

@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor):
lang=lang,
audio_format=audio_format)
time_end = time.time()
logger.info(res.json())
logger.info(res)
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to speech recognition.")
logger.error(e)
return False
@stats_wrapper
@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor):
logging.info("asr websocket client start")
handler = ASRAudioHandler(server_ip, port)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(input))
res = loop.run_until_complete(handler.run(input))
logging.info("asr websocket client finished")
return res['asr_results']
@cli_client_register(
name='paddlespeech_client.cls', description='visit cls service')

@ -41,11 +41,7 @@ asr_online:
shift_ms: 40
sample_rate: 16000
sample_width: 2
vad_conf:
aggressiveness: 2
sample_rate: 16000
frame_duration_ms: 20
sample_width: 2
padding_ms: 200
padding_ratio: 0.9
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 20 # ms
shift_ms: 10 # ms

@ -0,0 +1,45 @@
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only supports the online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'conformer_online_multicn'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 25 # ms
shift_ms: 10 # ms
sample_rate: 16000
sample_width: 2
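For reference, the nested sections of a config like this become attribute access once loaded with yacs, which is the style the websocket route below relies on (e.g. `chunk_buffer_conf.window_n`). The loading sketch assumes the same `yaml.safe_load` + `CfgNode` pattern used elsewhere in this change; it is illustrative, not the server's own loading code.
```python
# Illustrative sketch: read conf/ws_conformer_application.yaml into a yacs
# CfgNode so the nested keys can be accessed as attributes, mirroring how
# the websocket route reads asr_engine.config.chunk_buffer_conf.
import yaml
from yacs.config import CfgNode

with open("conf/ws_conformer_application.yaml") as f:
    config = CfgNode(yaml.safe_load(f))

assert "asr_online" in config.engine_list
buf = config.asr_online.chunk_buffer_conf
print(buf.window_n, buf.shift_n, buf.window_ms, buf.shift_ms)  # 7 4 25 10
```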

File diff suppressed because it is too large.

@ -0,0 +1,128 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddlespeech.cli.log import logger
from paddlespeech.s2t.utils.utility import log_add
__all__ = ['CTCPrefixBeamSearch']
class CTCPrefixBeamSearch:
def __init__(self, config):
"""Implement the ctc prefix beam search
Args:
config (yacs.config.CfgNode): the decoding config, including beam_size
"""
self.config = config
self.reset()
@paddle.no_grad()
def search(self, ctc_probs, device, blank_id=0):
"""ctc prefix beam search method decode a chunk feature
Args:
ctc_probs (paddle.Tensor): the ctc probability of all the tokens
device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0).
blank_id (int, optional): the blank id in the vocab. Defaults to 0.
Returns:
list: the search result
"""
# decode
logger.info("start to ctc prefix search")
batch_size = 1
beam_size = self.config.beam_size
maxlen = ctc_probs.shape[0]
assert len(ctc_probs.shape) == 2
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
if self.cur_hyps is None:
self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in self.cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
self.cur_hyps = next_hyps[:beam_size]
self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
logger.info("ctc prefix search success")
return self.hyps
def get_one_best_hyps(self):
"""Return the one best result
Returns:
list: the one best result
"""
return [self.hyps[0][0]]
def get_hyps(self):
"""Return the search hyps
Returns:
list: return the search hyps
"""
return self.hyps
def reset(self):
"""Rest the search cache value
"""
self.cur_hyps = None
self.hyps = None
def finalize_search(self):
"""do nothing in ctc_prefix_beam_search
"""
pass
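As a hedged usage sketch (not part of this change), the class above can be driven chunk by chunk: `cur_hyps` persists across `search` calls, so successive chunks of per-frame log-probabilities extend the same prefixes. Only `beam_size` is read from the config in the code shown, so a `SimpleNamespace` stands in for the real yacs config, and the vocabulary size and random inputs are made up.
```python
# Hedged usage sketch: drive CTCPrefixBeamSearch with random per-frame
# log-probabilities. Only `beam_size` is read from the config above, so a
# SimpleNamespace stands in for the real yacs CfgNode; vocabulary size and
# chunk shapes are illustrative only.
from types import SimpleNamespace

import paddle
import paddle.nn.functional as F

searcher = CTCPrefixBeamSearch(SimpleNamespace(beam_size=5))

vocab_size, blank_id = 30, 0
for _ in range(3):  # three streaming chunks of 20 frames each
    chunk_logp = F.log_softmax(paddle.randn([20, vocab_size]), axis=-1)
    searcher.search(chunk_logp, device=paddle.CPUPlace(), blank_id=blank_id)

print(searcher.get_one_best_hyps())  # best token-id prefix so far
searcher.finalize_search()
searcher.reset()
```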

@ -12,24 +12,329 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import math
import os
import time
from typing import Optional
import numpy as np
import paddle
import yaml
from yacs.config import CfgNode
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
__all__ = ['TTSEngine']
# support online model
pretrained_models = {
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_cnndecoder_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip',
'md5':
'6eb28e22ace73e0ebe7845f86478f89f',
'config':
'cnndecoder.yaml',
'ckpt':
'snapshot_iter_153000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
}
model_alias = {
# acoustic model
"fastspeech2":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2",
"fastspeech2_inference":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
# voc
"mb_melgan":
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
}
__all__ = ['TTSEngine']
class TTSServerExecutor(TTSExecutor):
def __init__(self):
def __init__(self, am_block, am_pad, voc_block, voc_pad):
super().__init__()
pass
self.am_block = am_block
self.am_pad = am_pad
self.voc_block = voc_block
self.voc_pad = voc_pad
def get_model_info(self,
field: str,
model_name: str,
ckpt: Optional[os.PathLike],
stat: Optional[os.PathLike]):
"""get model information
Args:
field (str): am or voc
model_name (str): model type; supports fastspeech2, hifigan, mb_melgan
ckpt (Optional[os.PathLike]): ckpt file
stat (Optional[os.PathLike]): stat file, including mean and standard deviation
Returns:
[module]: model module
[Tensor]: mean
[Tensor]: standard deviation
"""
model_class = dynamic_import(model_name, model_alias)
if field == "am":
odim = self.am_config.n_mels
model = model_class(
idim=self.vocab_size, odim=odim, **self.am_config["model"])
model.set_state_dict(paddle.load(ckpt)["main_params"])
elif field == "voc":
model = model_class(**self.voc_config["generator_params"])
model.set_state_dict(paddle.load(ckpt)["generator_params"])
model.remove_weight_norm()
else:
logger.error("Please set correct field, am or voc")
model.eval()
model_mu, model_std = np.load(stat)
model_mu = paddle.to_tensor(model_mu)
model_std = paddle.to_tensor(model_std)
return model, model_mu, model_std
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and return the pretrained resource path for the current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(
self,
am: str='fastspeech2_csmsc',
am_config: Optional[os.PathLike]=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
voc: str='mb_melgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None,
lang: str='zh', ):
"""
Init model and other resources from a specific path.
"""
if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'):
logger.info('Models had been initialized.')
return
# am model info
am_tag = am + '-' + lang
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_config = os.path.join(am_res_path,
pretrained_models[am_tag]['config'])
self.am_ckpt = os.path.join(am_res_path,
pretrained_models[am_tag]['ckpt'])
self.am_stat = os.path.join(
am_res_path, pretrained_models[am_tag]['speech_stats'])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, pretrained_models[am_tag]['phones_dict'])
print("self.phones_dict:", self.phones_dict)
logger.info(am_res_path)
logger.info(self.am_config)
logger.info(self.am_ckpt)
else:
self.am_config = os.path.abspath(am_config)
self.am_ckpt = os.path.abspath(am_ckpt)
self.am_stat = os.path.abspath(am_stat)
self.phones_dict = os.path.abspath(phones_dict)
self.am_res_path = os.path.dirname(os.path.abspath(self.am_config))
print("self.phones_dict:", self.phones_dict)
self.tones_dict = None
self.speaker_dict = None
# voc model info
voc_tag = voc + '-' + lang
if voc_ckpt is None or voc_config is None or voc_stat is None:
voc_res_path = self._get_pretrained_path(voc_tag)
self.voc_res_path = voc_res_path
self.voc_config = os.path.join(voc_res_path,
pretrained_models[voc_tag]['config'])
self.voc_ckpt = os.path.join(voc_res_path,
pretrained_models[voc_tag]['ckpt'])
self.voc_stat = os.path.join(
voc_res_path, pretrained_models[voc_tag]['speech_stats'])
logger.info(voc_res_path)
logger.info(self.voc_config)
logger.info(self.voc_ckpt)
else:
self.voc_config = os.path.abspath(voc_config)
self.voc_ckpt = os.path.abspath(voc_ckpt)
self.voc_stat = os.path.abspath(voc_stat)
self.voc_res_path = os.path.dirname(
os.path.abspath(self.voc_config))
# Init body.
with open(self.am_config) as f:
self.am_config = CfgNode(yaml.safe_load(f))
with open(self.voc_config) as f:
self.voc_config = CfgNode(yaml.safe_load(f))
with open(self.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
self.vocab_size = len(phn_id)
print("vocab_size:", self.vocab_size)
# frontend
if lang == 'zh':
self.frontend = Frontend(
phone_vocab_path=self.phones_dict,
tone_vocab_path=self.tones_dict)
elif lang == 'en':
self.frontend = English(phone_vocab_path=self.phones_dict)
print("frontend done!")
# am infer info
self.am_name = am[:am.rindex('_')]
if self.am_name == "fastspeech2_cnndecoder":
self.am_inference, self.am_mu, self.am_std = self.get_model_info(
"am", "fastspeech2", self.am_ckpt, self.am_stat)
else:
am, am_mu, am_std = self.get_model_info("am", self.am_name,
self.am_ckpt, self.am_stat)
am_normalizer = ZScore(am_mu, am_std)
am_inference_class = dynamic_import(self.am_name + '_inference',
model_alias)
self.am_inference = am_inference_class(am_normalizer, am)
self.am_inference.eval()
print("acoustic model done!")
# voc infer info
self.voc_name = voc[:voc.rindex('_')]
voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name,
self.voc_ckpt, self.voc_stat)
voc_normalizer = ZScore(voc_mu, voc_std)
voc_inference_class = dynamic_import(self.voc_name + '_inference',
model_alias)
self.voc_inference = voc_inference_class(voc_normalizer, voc)
self.voc_inference.eval()
print("voc done!")
def get_phone(self, sentence, lang, merge_sentences, get_tone_ids):
tone_ids = None
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
"""
Remove the output produced by the pad frames that were added for streaming inference.
"""
front_pad = min(chunk_id * block, pad)
# first chunk
if chunk_id == 0:
data = data[:block * upsample]
# last chunk
elif chunk_id == chunk_num - 1:
data = data[front_pad * upsample:]
# middle chunk
else:
data = data[front_pad * upsample:(front_pad + block) * upsample]
return data
@paddle.no_grad()
def infer(
@ -37,16 +342,20 @@ class TTSServerExecutor(TTSExecutor):
text: str,
lang: str='zh',
am: str='fastspeech2_csmsc',
spk_id: int=0,
am_block: int=42,
am_pad: int=12,
voc_block: int=14,
voc_pad: int=14, ):
spk_id: int=0, ):
"""
Model inference and result stored in self.output.
"""
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_block = self.am_block
am_pad = self.am_pad
am_upsample = 1
voc_block = self.voc_block
voc_pad = self.voc_pad
voc_upsample = self.voc_config.n_shift
# first_flag marks the first packet (used for first-response timing)
first_flag = 1
get_tone_ids = False
merge_sentences = False
frontend_st = time.time()
@ -64,43 +373,100 @@ class TTSServerExecutor(TTSExecutor):
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
self.frontend_time = time.time() - frontend_st
frontend_et = time.time()
self.frontend_time = frontend_et - frontend_st
for i in range(len(phone_ids)):
am_st = time.time()
part_phone_ids = phone_ids[i]
# am
if am_name == 'speedyspeech':
part_tone_ids = tone_ids[i]
mel = self.am_inference(part_phone_ids, part_tone_ids)
# fastspeech2
voc_chunk_id = 0
# fastspeech2_csmsc
if am == "fastspeech2_csmsc":
# am
mel = self.am_inference(part_phone_ids)
if first_flag == 1:
first_am_et = time.time()
self.first_am_infer = first_am_et - frontend_et
# voc streaming
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
voc_chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_inference(mel_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
voc_block, voc_pad, voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
self.first_response_time = first_voc_et - frontend_st
first_flag = 0
yield sub_wav
# fastspeech2_cnndecoder_csmsc
elif am == "fastspeech2_cnndecoder_csmsc":
# am
orig_hs, h_masks = self.am_inference.encoder_infer(
part_phone_ids)
# streaming voc chunk info
mel_len = orig_hs.shape[1]
voc_chunk_num = math.ceil(mel_len / self.voc_block)
start = 0
end = min(self.voc_block + self.voc_pad, mel_len)
# streaming am
hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
before_outs, _ = self.am_inference.decoder(hs)
after_outs = before_outs + self.am_inference.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
normalized_mel = after_outs[0]
sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block,
am_pad, am_upsample)
if i == 0:
mel_streaming = sub_mel
else:
mel_streaming = np.concatenate(
(mel_streaming, sub_mel), axis=0)
# streaming voc
# start streaming vocoder inference once the streaming AM has produced more mel frames than the vocoder chunk size
while (mel_streaming.shape[0] >= end and
voc_chunk_id < voc_chunk_num):
if first_flag == 1:
first_am_et = time.time()
self.first_am_infer = first_am_et - frontend_et
voc_chunk = mel_streaming[start:end, :]
voc_chunk = paddle.to_tensor(voc_chunk)
sub_wav = self.voc_inference(voc_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num,
voc_chunk_id, voc_block,
voc_pad, voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
self.first_response_time = first_voc_et - frontend_st
first_flag = 0
yield sub_wav
voc_chunk_id += 1
start = max(0, voc_chunk_id * voc_block - voc_pad)
end = min((voc_chunk_id + 1) * voc_block + voc_pad,
mel_len)
else:
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
mel = self.am_inference(
part_phone_ids, spk_id=paddle.to_tensor(spk_id))
else:
mel = self.am_inference(part_phone_ids)
am_et = time.time()
# voc streaming
voc_upsample = self.voc_config.n_shift
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_inference(mel_chunk)
front_pad = min(i * voc_block, voc_pad)
if i == 0:
sub_wav = sub_wav[:voc_block * voc_upsample]
elif i == chunk_num - 1:
sub_wav = sub_wav[front_pad * voc_upsample:]
else:
sub_wav = sub_wav[front_pad * voc_upsample:(
front_pad + voc_block) * voc_upsample]
yield sub_wav
logger.error(
"Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts."
)
self.final_response_time = time.time() - frontend_st
class TTSEngine(BaseEngine):
@ -113,14 +479,21 @@ class TTSEngine(BaseEngine):
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super(TTSEngine, self).__init__()
super().__init__()
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config
assert "fastspeech2_csmsc" in config.am and (
config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc"
assert (
config.am == "fastspeech2_csmsc" or
config.am == "fastspeech2_cnndecoder_csmsc"
) and (
config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc"
), 'Please check config: am supports fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc, voc supports hifigan_csmsc or mb_melgan_csmsc.'
assert (
config.voc_block > 0 and config.voc_pad > 0
), "Please set correct voc_block and voc_pad, they should be more than 0."
try:
if self.config.device:
self.device = self.config.device
@ -135,6 +508,9 @@ class TTSEngine(BaseEngine):
(self.device))
return False
self.executor = TTSServerExecutor(config.am_block, config.am_pad,
config.voc_block, config.voc_pad)
try:
self.executor._init_from_path(
am=self.config.am,
@ -155,15 +531,42 @@ class TTSEngine(BaseEngine):
(self.device))
return False
self.am_block = self.config.am_block
self.am_pad = self.config.am_pad
self.voc_block = self.config.voc_block
self.voc_pad = self.config.voc_pad
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device))
# warm up
try:
self.warm_up()
except Exception as e:
logger.error("Failed to warm up on tts engine.")
return False
return True
def warm_up(self):
"""warm up
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info(
"*******************************warm up ********************************"
)
for i in range(3):
for wav in self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, ):
logger.info(
f"The first response time of the {i} warm up: {self.executor.first_response_time} s"
)
break
logger.info(
"**********************************************************************"
)
def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
# Convert byte to text
if text_bese64:
@ -195,18 +598,14 @@ class TTSEngine(BaseEngine):
wav_base64: The base64 format of the synthesized audio.
"""
lang = self.config.lang
wav_list = []
for wav in self.executor.infer(
text=sentence,
lang=lang,
lang=self.config.lang,
am=self.config.am,
spk_id=spk_id,
am_block=self.am_block,
am_pad=self.am_pad,
voc_block=self.voc_block,
voc_pad=self.voc_pad):
spk_id=spk_id, ):
# wav type: <class 'numpy.ndarray'> float32, convert to pcm (base64)
wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes
@ -216,5 +615,14 @@ class TTSEngine(BaseEngine):
yield wav_base64
wav_all = np.concatenate(wav_list, axis=0)
logger.info("The durations of audio is: {} s".format(
len(wav_all) / self.executor.am_config.fs))
duration = len(wav_all) / self.executor.am_config.fs
logger.info(f"sentence: {sentence}")
logger.info(f"The durations of audio is: {duration} s")
logger.info(
f"first response time: {self.executor.first_response_time} s")
logger.info(
f"final response time: {self.executor.final_response_time} s")
logger.info(f"RTF: {self.executor.final_response_time / duration}")
logger.info(
f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
)
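As a standalone sketch of the chunk/pad bookkeeping used by the streaming `infer` above: each chunk is synthesized with up to `pad` frames of context on either side and `depadding` trims the overlap so the concatenated output matches a non-streaming pass. The slicing that builds the padded chunks here is a simplified stand-in for `get_chunks`, and the sizes are made up.
```python
# Hedged sketch of the streaming depadding arithmetic. `depad` mirrors
# TTSServerExecutor.depadding; the chunk slicing below is a simplified
# stand-in for get_chunks and the block/pad sizes are illustrative only.
import numpy as np


def depad(data, chunk_num, chunk_id, block, pad, upsample):
    front_pad = min(chunk_id * block, pad)
    if chunk_id == 0:                    # first chunk: keep only the block
        return data[:block * upsample]
    if chunk_id == chunk_num - 1:        # last chunk: drop the front context
        return data[front_pad * upsample:]
    return data[front_pad * upsample:(front_pad + block) * upsample]


block, pad, upsample = 4, 2, 1
mel = np.arange(12)  # 12 "frames" -> 3 blocks of 4 frames
chunks = [mel[max(0, i * block - pad):(i + 1) * block + pad] for i in range(3)]
out = np.concatenate(
    [depad(c, len(chunks), i, block, pad, upsample) for i, c in enumerate(chunks)])
assert (out == mel).all()  # trimming the pad context reconstructs the sequence
```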

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -34,10 +34,9 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size != 0:
chunk_size = 85 * 16  #85ms, sample_rate = 16kHz
if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
padding_len_x = 0
@ -48,7 +47,6 @@ class ASRAudioHandler:
assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk)
for i in range(0, num_chunk):
start = i * chunk_size
end = start + chunk_size
@ -57,7 +55,11 @@ class ASRAudioHandler:
async def run(self, wavfile_path: str):
logging.info("send a message to the server")
# self.read_wave()
# send websocket handshake protocol
async with websockets.connect(self.url) as ws:
# the server has already completed the websocket handshake
# client start to send the command
audio_info = json.dumps(
{
"name": "test.wav",
@ -78,7 +80,6 @@ class ASRAudioHandler:
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
result = msg
# finished
audio_info = json.dumps(
{
@ -91,10 +92,12 @@ class ASRAudioHandler:
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
# decode the bytes to str
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
return result
logging.info("final receive msg={}".format(msg))
result = msg
return result
def main(args):

@ -63,12 +63,12 @@ class ChunkBuffer(object):
the sample rate.
Yields Frames of the requested duration.
"""
audio = self.remained_audio + audio
self.remained_audio = b''
offset = 0
timestamp = 0.0
while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec)
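For orientation, a hedged sketch of using `ChunkBuffer` on raw PCM bytes, with the constructor arguments the websocket route builds from `chunk_buffer_conf`; the silent audio and the 16 kHz / int16 assumptions are illustrative only.
```python
# Hedged sketch: slice a raw int16 PCM byte stream into analysis windows
# with ChunkBuffer, using the same constructor arguments the websocket
# route passes from chunk_buffer_conf. The silent audio below is made up.
from paddlespeech.server.utils.buffer import ChunkBuffer

chunk_buffer = ChunkBuffer(
    window_n=7, shift_n=4, window_ms=20, shift_ms=10,
    sample_rate=16000, sample_width=2)

pcm = b"\x00\x00" * 16000  # one second of int16 silence at 16 kHz
frames = list(chunk_buffer.frame_generator(pcm))
print(len(frames))  # number of analysis windows produced from this chunk
```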

@ -52,6 +52,10 @@ def get_chunks(data, block_size, pad_size, step):
Returns:
list: chunks list
"""
if block_size == -1:
return [data]
if step == "am":
data_len = data.shape[1]
elif step == "voc":

@ -13,12 +13,12 @@
# limitations under the License.
import json
import numpy as np
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio
@ -28,26 +28,29 @@ router = APIRouter()
@router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
connection_handler = None
# init buffer
# each websocket connection has its own chunk buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer(
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width'])
window_n=chunk_buffer_conf.window_n,
shift_n=chunk_buffer_conf.shift_n,
window_ms=chunk_buffer_conf.window_ms,
shift_ms=chunk_buffer_conf.shift_ms,
sample_rate=chunk_buffer_conf.sample_rate,
sample_width=chunk_buffer_conf.sample_width)
# init vad
vad_conf = asr_engine.config.vad_conf
vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
vad_conf = asr_engine.config.get('vad_conf', None)
if vad_conf:
vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
try:
while True:
@ -64,13 +67,21 @@ async def websocket_endpoint(websocket: WebSocket):
if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"}
# do something at the beginning here
# create the instance to process the audio
connection_handler = PaddleASRConnectionHanddler(asr_engine)
await websocket.send_json(resp)
elif message['signal'] == 'end':
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# reset the single engine for a new connection
asr_engine.reset()
resp = {"status": "ok", "signal": "finished"}
connection_handler.decode(is_finished=True)
connection_handler.rescoring()
asr_results = connection_handler.get_result()
connection_handler.reset()
resp = {
"status": "ok",
"signal": "finished",
'asr_results': asr_results
}
await websocket.send_json(resp)
break
else:
@ -79,21 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message:
message = message["bytes"]
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_results = ""
frames = chunk_buffer.frame_generator(message)
for frame in frames:
samples = np.frombuffer(frame.bytes, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate)
asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess()
connection_handler.extract_feat(message)
connection_handler.decode(is_finished=False)
asr_results = connection_handler.get_result()
asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results}
await websocket.send_json(resp)
except WebSocketDisconnect:
pass

@ -63,7 +63,8 @@ include(libsndfile)
# include(boost) # not work
set(boost_SOURCE_DIR ${fc_patch}/boost-src)
set(BOOST_ROOT ${boost_SOURCE_DIR})
# #find_package(boost REQUIRED PATHS ${BOOST_ROOT})
include_directories(${boost_SOURCE_DIR})
link_directories(${boost_SOURCE_DIR}/stage/lib)
# Eigen
include(eigen)
@ -141,4 +142,4 @@ set(DEPS ${DEPS}
set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
add_subdirectory(speechx)
add_subdirectory(examples)
add_subdirectory(examples)

@ -2,4 +2,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
add_subdirectory(decoder)
add_subdirectory(websocket)

@ -1,6 +1,6 @@
# This contains the locations of the binaries required for running the examples.
SPEECHX_ROOT=$PWD/../../../
SPEECHX_ROOT=$PWD/../../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat:$SPEECHX_EXAMPLES/ds2_ol/websocket
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -86,7 +86,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--params_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result
@ -101,7 +101,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ];then
ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--params_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm \
@ -128,7 +128,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
wfst-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--params_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
@ -137,4 +137,4 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
cat $data/split${nj}/*/result_tlg > $exp/${label_file}_tlg
utils/compute-wer.py --char=1 --v=1 $exp/${label_file}_tlg $text > $exp/${wer}_tlg
fi
fi

@ -0,0 +1,37 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char
# output
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
export GLOG_logtostderr=1
# websocket client
websocket_client_main \
--wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.36

@ -0,0 +1,66 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
# output
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
if [ ! -d $ckpt_dir ]; then
mkdir -p $ckpt_dir
wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $ckpt_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir
fi
export GLOG_logtostderr=1
# 3. gen cmvn
cmvn=$PWD/cmvn.ark
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
text=$data/test/text
graph_dir=./aishell_graph
if [ ! -d $graph_dir ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
unzip aishell_graph.zip
fi
# 5. test websocket server
websocket_server_main \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_1.jit.pdmodel \
--streaming_chunk=0.1 \
--convert2PCM32=true \
--params_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2

@ -17,3 +17,6 @@ add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc)
target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})

@ -34,12 +34,10 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
@ -58,12 +56,11 @@ int main(int argc, char* argv[]) {
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_path = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model path: " << model_path;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
@ -76,10 +73,9 @@ int main(int argc, char* argv[]) {
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.model_path = model_path;
model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
@ -125,7 +121,6 @@ int main(int argc, char* argv[]) {
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
int32 end = start + chunk_size;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);

@ -0,0 +1,85 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
int32 num_done = 0, num_err = 0;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
recognizer.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished();
}
recognizer.Decode();
sample_offset += cur_chunk_size;
}
std::string result;
result = recognizer.GetFinalResult();
recognizer.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
}

@ -73,9 +73,9 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
} catch (simdjson::simdjson_error& err) {
LOG(ERR) << err.what();
LOG(ERROR) << err.what();
}
return 0;
}
}

@ -32,7 +32,6 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
@ -66,7 +65,8 @@ int main(int argc, char* argv[]) {
std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
ppspeech::FeatureCacheOptions feat_cache_opts;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;

@ -0,0 +1,10 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc)
target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})
add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})

@ -0,0 +1,82 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_client.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(host, "127.0.0.1", "host of websocket server");
DEFINE_int32(port, 201314, "port of websocket server");
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
using kaldi::int16;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::WebSocketClient client(FLAGS_host, FLAGS_port);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
const int sample_rate = 16000;
const float streaming_chunk = FLAGS_streaming_chunk;
const int chunk_sample_size = streaming_chunk * sample_rate;
for (; !wav_reader.Done(); wav_reader.Next()) {
client.SendStartSignal();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
CHECK_EQ(wave_data.SampFreq(), sample_rate);
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
const int tot_samples = waveform.Dim();
int sample_offset = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<int16> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = static_cast<int16>(waveform(sample_offset + i));
}
client.SendBinaryData(wav_chunk.data(),
wav_chunk.size() * sizeof(int16));
sample_offset += cur_chunk_size;
LOG(INFO) << "Send " << cur_chunk_size << " samples";
std::this_thread::sleep_for(
std::chrono::milliseconds(static_cast<int>(1 * 1000)));
if (cur_chunk_size < chunk_sample_size) {
client.SendEndSignal();
}
}
while (!client.Done()) {
}
std::string result = client.GetResult();
LOG(INFO) << "utt: " << utt << " " << result;
client.Join();
return 0;
}
return 0;
}

@ -0,0 +1,30 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_server.h"
#include "decoder/param.h"
DEFINE_int32(port, 201314, "websocket listening port");
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::WebSocketServer server(FLAGS_port, resource);
LOG(INFO) << "Listening at port " << FLAGS_port;
server.Start();
return 0;
}

@ -30,4 +30,10 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder
)
add_subdirectory(decoder)
add_subdirectory(decoder)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/websocket
)
add_subdirectory(websocket)

@ -28,8 +28,10 @@
#include <sstream>
#include <stack>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "base/basic_types.h"

@ -7,5 +7,6 @@ add_library(decoder STATIC
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_tlg_decoder.cc
recognizer.cc
)
target_link_libraries(decoder PUBLIC kenlm utils fst)
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder)

@ -33,7 +33,6 @@ void TLGDecoder::InitDecoder() {
void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) {
LOG(INFO) << "num frame decode: " << frame_decoded_size_;
AdvanceDecoding(decodable.get());
}
}
@ -63,4 +62,4 @@ std::string TLGDecoder::GetFinalBestPath() {
}
return words;
}
}
}

@ -0,0 +1,94 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
DEFINE_bool(convert2PCM32, true, "audio convert to pcm32");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(params_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
namespace ppspeech {
// todo refactor later
FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk;
opts.convert2PCM32 = FLAGS_convert2PCM32;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.frame_length_ms = 20;
frame_opts.frame_shift_ms = 10;
frame_opts.remove_dc_offset = false;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
frame_opts.dither = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length;
opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate;
return opts;
}
ModelOptions InitModelOptions() {
ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.params_path = FLAGS_params_path;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.output_names = FLAGS_model_output_names;
return model_opts;
}
TLGDecoderOptions InitDecoderOptions() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
return decoder_opts;
}
RecognizerResource InitRecognizerResoure() {
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = InitFeaturePipelineOptions();
resource.model_opts = InitModelOptions();
resource.tlg_opts = InitDecoderOptions();
return resource;
}
}  // namespace ppspeech
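A minimal sketch of how these flag-driven helpers might be wired together in a decoding binary; the main program is not part of this hunk, and the gflags/glog setup shown here is an assumption:
// Hypothetical usage, not part of this diff: build the resource from the flags above.
int main(int argc, char* argv[]) {
    gflags::ParseCommandLineFlags(&argc, &argv, false);
    google::InitGoogleLogging(argv[0]);

    // Each option struct is filled from the DEFINE_* flags declared above.
    ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
    // resource.feature_pipeline_opts / model_opts / tlg_opts are now ready to
    // construct a Recognizer (declared in decoder/recognizer.h).
    return 0;
}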

@ -0,0 +1,60 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
Recognizer::Recognizer(const RecognizerResource& resource) {
// resource_ = resource;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(resource.model_opts));
BaseFloat ac_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale));
decoder_.reset(new TLGDecoder(resource.tlg_opts));
input_finished_ = false;
}
void Recognizer::Accept(const Vector<BaseFloat>& waves) {
feature_pipeline_->Accept(waves);
}
void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); }
std::string Recognizer::GetFinalResult() {
return decoder_->GetFinalBestPath();
}
void Recognizer::SetFinished() {
feature_pipeline_->SetFinished();
input_finished_ = true;
}
bool Recognizer::IsFinished() { return input_finished_; }
void Recognizer::Reset() {
feature_pipeline_->Reset();
decodable_->Reset();
decoder_->Reset();
}
} // namespace ppspeech
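For reference, a minimal streaming loop over this API; the wav source, the 100 ms chunk size, and the "decoder/param.h" header name are assumptions, not part of this diff:
#include <algorithm>
#include "decoder/param.h"       // assumed header holding InitRecognizerResoure()
#include "decoder/recognizer.h"

// Illustrative only: feed small pieces, decode incrementally, then flush.
void RecognizeSketch(const kaldi::Vector<kaldi::BaseFloat>& wav) {
    ppspeech::Recognizer recognizer(ppspeech::InitRecognizerResoure());
    const int chunk = 1600;  // 100 ms at 16 kHz, illustrative value
    for (int start = 0; start < wav.Dim(); start += chunk) {
        int len = std::min<int>(chunk, wav.Dim() - start);
        kaldi::Vector<kaldi::BaseFloat> piece(len);
        piece.CopyFromVec(wav.Range(start, len));
        recognizer.Accept(piece);
        recognizer.Decode();
    }
    recognizer.SetFinished();
    recognizer.Decode();  // drain the last cached feature chunk
    LOG(INFO) << "result: " << recognizer.GetFinalResult();
}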

@ -0,0 +1,59 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
namespace ppspeech {
struct RecognizerResource {
FeaturePipelineOptions feature_pipeline_opts;
ModelOptions model_opts;
TLGDecoderOptions tlg_opts;
// CTCBeamSearchOptions beam_search_opts;
kaldi::BaseFloat acoustic_scale;
RecognizerResource()
: acoustic_scale(1.0),
feature_pipeline_opts(),
model_opts(),
tlg_opts() {}
};
class Recognizer {
public:
explicit Recognizer(const RecognizerResource& resource);
void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
void Decode();
std::string GetFinalResult();
void SetFinished();
bool IsFinished();
void Reset();
private:
// std::shared_ptr<RecognizerResource> resource_;
// RecognizerResource resource_;
std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<TLGDecoder> decoder_;
bool input_finished_;
};
} // namespace ppspeech

@ -6,6 +6,7 @@ add_library(frontend STATIC
linear_spectrogram.cc
audio_cache.cc
feature_cache.cc
feature_pipeline.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix)
target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common)

@ -41,7 +41,7 @@ void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
ready_feed_condition_.wait(lock);
}
for (size_t idx = 0; idx < waves.Dim(); ++idx) {
int32 buffer_idx = (idx + offset_) % ring_buffer_.size();
int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx);
if (convert2PCM32_)
ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx));

@ -24,7 +24,7 @@ namespace ppspeech {
class AudioCache : public FrontendInterface {
public:
explicit AudioCache(int buffer_size = 1000 * kint16max,
bool convert2PCM32 = false);
bool convert2PCM32 = true);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);

@ -23,10 +23,13 @@ using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
FeatureCache::FeatureCache(int max_size,
FeatureCache::FeatureCache(FeatureCacheOptions opts,
unique_ptr<FrontendInterface> base_extractor) {
max_size_ = max_size;
max_size_ = opts.max_size;
frame_chunk_stride_ = opts.frame_chunk_stride;
frame_chunk_size_ = opts.frame_chunk_size;
base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim();
}
void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
@ -44,13 +47,14 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.empty() && base_extractor_->IsFinished() == false) {
ready_read_condition_.wait(lock);
BaseFloat elapsed = timer.Elapsed() * 1000;
// todo replace 1.0 with timeout_
if (elapsed > 1.0) {
// todo refactor: wait
// ready_read_condition_.wait(lock);
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
// todo replace 1 with timeout_, 1 ms
if (elapsed > 1) {
return false;
}
usleep(1000); // sleep 1 ms
usleep(100); // sleep 0.1 ms
}
if (cache_.empty()) return false;
feats->Resize(cache_.front().Dim());
@ -63,25 +67,41 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
// read all data from base_feature_extractor_ into cache_
bool FeatureCache::Compute() {
// compute and feed
Vector<BaseFloat> feature_chunk;
bool result = base_extractor_->Read(&feature_chunk);
Vector<BaseFloat> feature;
bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false;
int32 joint_len = feature.Dim() + remained_feature_.Dim();
int32 num_chunk =
((joint_len / dim_) - frame_chunk_size_) / frame_chunk_stride_ + 1;
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
ready_feed_condition_.wait(lock);
}
Vector<BaseFloat> joint_feature(joint_len);
joint_feature.Range(0, remained_feature_.Dim())
.CopyFromVec(remained_feature_);
joint_feature.Range(remained_feature_.Dim(), feature.Dim())
.CopyFromVec(feature);
// feed cache
if (feature_chunk.Dim() != 0) {
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * frame_chunk_stride_ * dim_;
Vector<BaseFloat> feature_chunk(frame_chunk_size_ * dim_);
SubVector<BaseFloat> tmp(joint_feature.Data() + start,
frame_chunk_size_ * dim_);
feature_chunk.CopyFromVec(tmp);
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
ready_feed_condition_.wait(lock);
}
// feed cache
cache_.push(feature_chunk);
ready_read_condition_.notify_one();
}
ready_read_condition_.notify_one();
int32 remained_feature_len =
joint_len - num_chunk * frame_chunk_stride_ * dim_;
remained_feature_.Resize(remained_feature_len);
remained_feature_.CopyFromVec(joint_feature.Range(
frame_chunk_stride_ * num_chunk * dim_, remained_feature_len));
return result;
}
void Reset() {
// std::lock_guard<std::mutex> lock(mutex_);
return;
}
} // namespace ppspeech
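The chunking arithmetic in Compute() is easier to follow with concrete numbers; the sketch below just replays the index math with the default flag values (frame_chunk_size = 7, frame_chunk_stride = 4) and is not additional library code:
// Worked example of the windowing above, counted in frames rather than floats.
// With 18 buffered frames, chunk size 7 and stride 4:
//   num_chunk = (18 - 7) / 4 + 1 = 3  -> chunks cover frames [0,7), [4,11), [8,15)
//   consumed  = num_chunk * 4 = 12    -> frames [12,18) stay in remained_feature_
int joint_frames = 18;       // joint_len / dim_
int frame_chunk_size = 7;    // FLAGS_receptive_field_length
int frame_chunk_stride = 4;  // FLAGS_downsampling_rate
int num_chunk = (joint_frames - frame_chunk_size) / frame_chunk_stride + 1;  // 3
int remained_frames = joint_frames - num_chunk * frame_chunk_stride;         // 6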

@ -19,10 +19,18 @@
namespace ppspeech {
struct FeatureCacheOptions {
int32 max_size;
int32 frame_chunk_size;
int32 frame_chunk_stride;
FeatureCacheOptions()
: max_size(kint16max), frame_chunk_size(1), frame_chunk_stride(1) {}
};
class FeatureCache : public FrontendInterface {
public:
explicit FeatureCache(
int32 max_size = kint16max,
FeatureCacheOptions opts,
std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves
@ -32,12 +40,15 @@ class FeatureCache : public FrontendInterface {
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// feat dim
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() {
// std::unique_lock<std::mutex> lock(mutex_);
base_extractor_->SetFinished();
LOG(INFO) << "set finished";
// read the last chunk data
Compute();
// ready_feed_condition_.notify_one();
}
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
@ -52,9 +63,13 @@ class FeatureCache : public FrontendInterface {
private:
bool Compute();
int32 dim_;
size_t max_size_;
std::unique_ptr<FrontendInterface> base_extractor_;
int32 frame_chunk_size_;
int32 frame_chunk_stride_;
kaldi::Vector<kaldi::BaseFloat> remained_feature_;
std::unique_ptr<FrontendInterface> base_extractor_;
std::mutex mutex_;
std::queue<kaldi::Vector<BaseFloat>> cache_;
std::condition_variable ready_feed_condition_;

@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/feature_pipeline.h"
namespace ppspeech {
using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.convert2PCM32));
unique_ptr<FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts,
std::move(data_source)));
unique_ptr<FrontendInterface> cmvn(
new ppspeech::CMVN(opts.cmvn_file, std::move(linear_spectrogram)));
base_extractor_.reset(
new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn)));
}
}  // namespace ppspeech

@ -0,0 +1,57 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file;
bool convert2PCM32;
LinearSpectrogramOptions linear_spectrogram_opts;
FeatureCacheOptions feature_cache_opts;
FeaturePipelineOptions()
: cmvn_file(""),
convert2PCM32(false),
linear_spectrogram_opts(),
feature_cache_opts() {}
};
class FeaturePipeline : public FrontendInterface {
public:
explicit FeaturePipeline(const FeaturePipelineOptions& opts);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
return base_extractor_->Read(feats);
}
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); }
private:
std::unique_ptr<FrontendInterface> base_extractor_;
};
}  // namespace ppspeech
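A minimal sketch of driving the pipeline on its own; the cmvn path and the single-shot Accept are assumptions for illustration, and every call simply forwards to the innermost FeatureCache:
// Illustrative only: push raw samples in, pull fixed-size feature chunks out.
ppspeech::FeaturePipelineOptions opts;
opts.cmvn_file = "cmvn.ark";  // assumed path, computed offline
opts.convert2PCM32 = true;

ppspeech::FeaturePipeline pipeline(opts);
kaldi::Vector<kaldi::BaseFloat> wav_chunk;  // filled by the caller (not shown)
pipeline.Accept(wav_chunk);
pipeline.SetFinished();

kaldi::Vector<kaldi::BaseFloat> feats;
while (pipeline.Read(&feats)) {
    LOG(INFO) << "read " << feats.Dim() / pipeline.Dim() << " frames";
}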

@ -52,16 +52,16 @@ bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
if (flag == false || input_feats.Dim() == 0) return false;
int32 feat_len = input_feats.Dim();
int32 left_len = reminded_wav_.Dim();
int32 left_len = remained_wav_.Dim();
Vector<BaseFloat> waves(feat_len + left_len);
waves.Range(0, left_len).CopyFromVec(reminded_wav_);
waves.Range(0, left_len).CopyFromVec(remained_wav_);
waves.Range(left_len, feat_len).CopyFromVec(input_feats);
Compute(waves, feats);
int32 frame_shift = opts_.frame_opts.WindowShift();
int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts);
int32 left_samples = waves.Dim() - frame_shift * num_frames;
reminded_wav_.Resize(left_samples);
reminded_wav_.CopyFromVec(
remained_wav_.Resize(left_samples);
remained_wav_.CopyFromVec(
waves.Range(frame_shift * num_frames, left_samples));
return true;
}

@ -25,12 +25,12 @@ struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts;
kaldi::BaseFloat streaming_chunk; // second
LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {}
LinearSpectrogramOptions() : streaming_chunk(0.1), frame_opts() {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("streaming-chunk",
&streaming_chunk,
"streaming chunk size, default: 0.36 sec");
"streaming chunk size, default: 0.1 sec");
frame_opts.Register(opts);
}
};
@ -48,7 +48,7 @@ class LinearSpectrogram : public FrontendInterface {
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
reminded_wav_.Resize(0);
remained_wav_.Resize(0);
}
private:
@ -60,7 +60,7 @@ class LinearSpectrogram : public FrontendInterface {
kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_;
std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> reminded_wav_;
kaldi::Vector<kaldi::BaseFloat> remained_wav_;
int chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};

@ -78,7 +78,6 @@ bool Decodable::AdvanceChunk() {
}
int32 nnet_dim = 0;
Vector<BaseFloat> inferences;
Matrix<BaseFloat> nnet_cache_tmp;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences);

@ -1,5 +1,6 @@
#!/bin/bash
set -e
# Audio classification
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/cat.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/dog.wav
paddlespeech cls --input ./cat.wav --topk 10
@ -28,26 +29,16 @@ paddlespeech tts --am tacotron2_csmsc --input "你好,欢迎使用百度飞桨
paddlespeech tts --am tacotron2_csmsc --voc wavernn_csmsc --input "你好,欢迎使用百度飞桨深度学习框架!"
paddlespeech tts --am tacotron2_ljspeech --voc pwgan_ljspeech --lang en --input "Life was like a box of chocolates, you never know what you're gonna get."
# Speech Translation (only supported on Linux)
paddlespeech st --input ./en.wav
# batch process
echo -e "1 欢迎光临。\n2 谢谢惠顾。" | paddlespeech tts
# shell pipeline
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
# stats
paddlespeech stats --task asr
paddlespeech stats --task tts
paddlespeech stats --task cls
# Speaker Verification
wget -c https://paddlespeech.bj.bcebos.com/vector/audio/85236145389.wav
paddlespeech vector --task spk --input 85236145389.wav
# batch process
echo -e "1 欢迎光临。\n2 谢谢惠顾。" | paddlespeech tts
echo -e "demo1 85236145389.wav \n demo2 85236145389.wav" > vec.job
paddlespeech vector --task spk --input vec.job
@ -55,4 +46,13 @@ echo -e "demo3 85236145389.wav \n demo4 85236145389.wav" | paddlespeech vector -
rm 85236145389.wav
rm vec.job
# shell pipeline
paddlespeech asr --input ./zh.wav | paddlespeech text --task punc
# stats
paddlespeech stats --task asr
paddlespeech stats --task tts
paddlespeech stats --task cls
paddlespeech stats --task text
paddlespeech stats --task vector
paddlespeech stats --task st
