Add paddlespeech.resource.

pull/1917/head
Author: KP
parent 775c4befbd
commit fa6e44e4ff

@@ -33,11 +33,8 @@ from ..utils import CLI_TIMER
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from ..utils import timer_register
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor']
@@ -46,10 +43,7 @@ __all__ = ['ASRExecutor']
@timer_register
class ASRExecutor(BaseExecutor):
def __init__(self):
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
super().__init__(task='asr', inference_type='offline')
self.parser = argparse.ArgumentParser(
prog='paddlespeech.asr', add_help=True)
self.parser.add_argument(
@@ -59,7 +53,8 @@ class ASRExecutor(BaseExecutor):
type=str,
default='conformer_wenetspeech',
choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys()
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
help='Choose model type of asr task.')
self.parser.add_argument(
@@ -141,14 +136,14 @@ class ASRExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.task_resource.set_task_model(tag, version=None)
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join(
res_path, self.pretrained_models[tag]['cfg_path'])
self.res_path, self.task_resource.res_dict['cfg_path'])
self.ckpt_path = os.path.join(
res_path,
self.pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path)
self.res_path,
self.task_resource.res_dict['ckpt_path'] + ".pdparams")
logger.info(self.res_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
@@ -172,8 +167,8 @@ class ASRExecutor(BaseExecutor):
self.collate_fn_test = SpeechCollator.from_config(self.config)
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab)
lm_url = self.pretrained_models[tag]['lm_url']
lm_md5 = self.pretrained_models[tag]['lm_md5']
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
@@ -191,7 +186,7 @@ class ASRExecutor(BaseExecutor):
raise Exception("wrong type")
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, self.model_alias)
model_class = self.task_resource.get_model_class(model_name)
model_conf = self.config
model = model_class.from_config(model_conf)
self.model = model
@@ -438,7 +433,7 @@ class ASRExecutor(BaseExecutor):
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_source = self.get_input_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
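
For context on this refactor: tags follow "{model_name}[_{dataset}]-{lang}-{sr}", and the CLI's --model choices are the tag prefixes before the first '-'. A minimal runnable sketch of both directions of that mapping (tags taken from the tables below; everything else illustrative):

```python
# Deriving --model choices from pretrained-model tags, as the argparse
# setup above does with task_resource.pretrained_models.keys().
tags = [
    "conformer_wenetspeech-zh-16k",
    "transformer_librispeech-en-16k",
    "deepspeech2offline_aishell-zh-16k",
]
choices = [tag[:tag.index('-')] for tag in tags]
print(choices)
# ['conformer_wenetspeech', 'transformer_librispeech', 'deepspeech2offline_aishell']

# Rebuilding a tag from CLI arguments, as _init_from_path does before
# calling task_resource.set_task_model(tag, version=None).
model_type, lang, sample_rate = "conformer_wenetspeech", "zh", 16000
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
print(tag)  # conformer_wenetspeech-zh-16k
```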

@@ -1,151 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"conformer_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
'md5':
'76cb19ed857e6623856b7cd7ebbfeda4',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/wenetspeech',
},
"conformer_online_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz',
'md5':
'b8c02632b04da34aca88459835be54a6',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_10',
},
"conformer_online_multicn-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz',
'md5':
'7989b3248c898070904cf042fd656003',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/multi_cn',
},
"conformer_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz',
'md5':
'3f073eccfa7bb14e0c6867d65fc0dc3a',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/avg_30',
},
"conformer_online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz',
'md5':
'b374cfb93537761270b6224fb0bfc26a',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_30',
},
"transformer_librispeech-en-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'2c667da24922aad391eacafe37bc1660',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/transformer/checkpoints/avg_10',
},
"deepspeech2online_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz',
'md5':
'e393d4d274af0f6967db24fc146e8074',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_10',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2offline_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz',
'md5':
'98b87b171b7240b7cae6e07d8d0bc9be',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"deepspeech2offline_librispeech-en-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'f5666c81ad015c8de03aac2bc92e5762',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
'lm_md5':
'099a601759d467cd0a8523ff939819c5'
},
}
model_alias = {
"deepspeech2offline":
"paddlespeech.s2t.models.ds2:DeepSpeech2Model",
"deepspeech2online":
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer":
"paddlespeech.s2t.models.u2:U2Model",
"conformer_online":
"paddlespeech.s2t.models.u2:U2Model",
"transformer":
"paddlespeech.s2t.models.u2:U2Model",
"wenetspeech":
"paddlespeech.s2t.models.u2:U2Model",
}
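
The model_alias values above are "module:Class" locator strings resolved at runtime. A minimal sketch of that resolution using only the standard library (paddlespeech's dynamic_import plays this role; import_locator below is an illustrative stand-in, not the library helper):

```python
import importlib

def import_locator(locator: str):
    """Resolve a "package.module:ClassName" locator to the class object.

    Illustrative stand-in for the dynamic_import helper used above.
    """
    module_path, _, class_name = locator.partition(':')
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# Requires a PaddleSpeech install:
# U2Model = import_locator("paddlespeech.s2t.models.u2:U2Model")
```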

@@ -11,17 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import List
from prettytable import PrettyTable
from ..resource import CommonTaskResource
from .entry import commands
from .utils import cli_register
from .utils import explicit_command_register
from .utils import get_command
__all__ = [
'BaseCommand',
'HelpCommand',
]
__all__ = ['BaseCommand', 'HelpCommand', 'StatsCommand']
@cli_register(name='paddlespeech')
@@ -76,6 +77,59 @@ class VersionCommand:
return True
model_name_format = {
'asr': 'Model-Language-Sample Rate',
'cls': 'Model-Sample Rate',
'st': 'Model-Source language-Target language',
'text': 'Model-Task-Language',
'tts': 'Model-Language',
'vector': 'Model-Sample Rate'
}
@cli_register(
name='paddlespeech.stats',
description='Get speech tasks support models list.')
class StatsCommand:
def __init__(self):
self.parser = argparse.ArgumentParser(
prog='paddlespeech.stats', add_help=True)
self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector']
self.parser.add_argument(
'--task',
type=str,
default='asr',
choices=self.task_choices,
help='Choose speech task.',
required=True)
def show_support_models(self, pretrained_models: dict):
fields = model_name_format[self.task].split("-")
table = PrettyTable(fields)
for key in pretrained_models:
table.add_row(key.split("-"))
print(table)
def execute(self, argv: List[str]) -> bool:
parser_args = self.parser.parse_args(argv)
self.task = parser_args.task
if self.task not in self.task_choices:
print("Please input correct speech task, choices = " + str(
self.task_choices))
return
pretrained_models = CommonTaskResource(task=self.task).pretrained_models
try:
print(
"Here is the list of {} pretrained models released by PaddleSpeech that can be used by command line and python API"
.format(self.task.upper()))
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of {} pretrained models.".format(
self.task.upper()))
# Dynamic import when running specific command
_commands = {
'asr': ['Speech to text infer command.', 'ASRExecutor'],
@@ -91,3 +145,4 @@ for com, info in _commands.items():
name='paddlespeech.{}'.format(com),
description=info[0],
cls='paddlespeech.cli.{}.{}'.format(com, info[1]))
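
The table rendering in show_support_models splits each tag on '-' into the per-task column layout from model_name_format. A standalone sketch with example ASR tags (prettytable is the same third-party package the command imports above):

```python
from prettytable import PrettyTable

fields = 'Model-Language-Sample Rate'.split('-')  # model_name_format['asr']
table = PrettyTable(fields)
for tag in ("conformer_wenetspeech-zh-16k", "transformer_librispeech-en-16k"):
    table.add_row(tag.split('-'))
print(table)
```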

@@ -21,26 +21,19 @@ from typing import Union
import numpy as np
import paddle
import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.utils.dynamic_import import dynamic_import
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
__all__ = ['CLSExecutor']
class CLSExecutor(BaseExecutor):
def __init__(self):
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
super().__init__(task='cls')
self.parser = argparse.ArgumentParser(
prog='paddlespeech.cls', add_help=True)
self.parser.add_argument(
@@ -50,7 +43,8 @@ class CLSExecutor(BaseExecutor):
type=str,
default='panns_cnn14',
choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys()
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
help='Choose model type of cls task.')
self.parser.add_argument(
@@ -103,13 +97,16 @@ class CLSExecutor(BaseExecutor):
if label_file is None or ckpt_path is None:
tag = model_type + '-' + '32k' # panns_cnn14-32k
self.res_path = self._get_pretrained_path(tag)
self.task_resource.set_task_model(tag, version=None)
self.cfg_path = os.path.join(
self.res_path, self.pretrained_models[tag]['cfg_path'])
self.task_resource.res_dir,
self.task_resource.res_dict['cfg_path'])
self.label_file = os.path.join(
self.res_path, self.pretrained_models[tag]['label_file'])
self.task_resource.res_dir,
self.task_resource.res_dict['label_file'])
self.ckpt_path = os.path.join(
self.res_path, self.pretrained_models[tag]['ckpt_path'])
self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.label_file = os.path.abspath(label_file)
@@ -126,7 +123,7 @@ class CLSExecutor(BaseExecutor):
self._label_list.append(line.strip())
# model
model_class = dynamic_import(model_type, self.model_alias)
model_class = self.task_resource.get_model_class(model_type)
model_dict = paddle.load(self.ckpt_path)
self.model = model_class(extract_embedding=False)
self.model.set_state_dict(model_dict)
@@ -203,7 +200,7 @@ class CLSExecutor(BaseExecutor):
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_source = self.get_input_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
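
The CLS checkpoint loading above is plain Paddle state-dict restoration. A short sketch under the assumption of a local PaddleSpeech install and a downloaded cnn14.pdparams (class and file names from the CLS table below):

```python
import paddle
from paddlespeech.cls.models.panns import CNN14  # alias "panns_cnn14"

# extract_embedding=False makes the network output class scores, as in
# CLSExecutor._init_from_path above.
model = CNN14(extract_embedding=False)
model.set_state_dict(paddle.load('cnn14.pdparams'))  # res_dict['ckpt_path']
model.eval()
```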

@@ -1,47 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"panns_cnn6-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
'md5': '4cf09194a95df024fd12f84712cf0f9c',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn6.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn10-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
'md5': 'cb8427b22176cc2116367d14847f5413',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn10.pdparams',
'label_file': 'audioset_labels.txt',
},
"panns_cnn14-32k": {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
'md5': 'e3b9b5614a1595001161d0ab95edee97',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn14.pdparams',
'label_file': 'audioset_labels.txt',
},
}
model_alias = {
"panns_cnn6": "paddlespeech.cls.models.panns:CNN6",
"panns_cnn10": "paddlespeech.cls.models.panns:CNN10",
"panns_cnn14": "paddlespeech.cls.models.panns:CNN14",
}

@@ -24,9 +24,8 @@ from typing import Union
import paddle
from ..resource import CommonTaskResource
from .log import logger
from .utils import download_and_decompress
from .utils import MODEL_HOME
class BaseExecutor(ABC):
@@ -34,11 +33,10 @@ class BaseExecutor(ABC):
An abstract executor of paddlespeech tasks.
"""
def __init__(self):
def __init__(self, task: str, **kwargs):
self._inputs = OrderedDict()
self._outputs = OrderedDict()
self.pretrained_models = OrderedDict()
self.model_alias = OrderedDict()
self.task_resource = CommonTaskResource(task=task, **kwargs)
@abstractmethod
def _init_from_path(self, *args, **kwargs):
@@ -98,8 +96,8 @@ class BaseExecutor(ABC):
"""
pass
def get_task_source(self, input_: Union[str, os.PathLike, None]
) -> Dict[str, Union[str, os.PathLike]]:
def get_input_source(self, input_: Union[str, os.PathLike, None]
) -> Dict[str, Union[str, os.PathLike]]:
"""
Get task input source from command line input.
@@ -115,15 +113,17 @@ class BaseExecutor(ABC):
ret = OrderedDict()
if input_ is None: # Take input from stdin
for i, line in enumerate(sys.stdin):
line = line.strip()
if len(line.split(' ')) == 1:
ret[str(i + 1)] = line
elif len(line.split(' ')) == 2:
id_, info = line.split(' ')
ret[id_] = info
else: # No valid input info from one line.
continue
if not sys.stdin.isatty(
): # Avoid getting stuck when stdin is empty.
for i, line in enumerate(sys.stdin):
line = line.strip()
if len(line.split(' ')) == 1:
ret[str(i + 1)] = line
elif len(line.split(' ')) == 2:
id_, info = line.split(' ')
ret[id_] = info
else: # No valid input info from one line.
continue
else:
ret[1] = input_
return ret
@@ -219,23 +219,6 @@ class BaseExecutor(ABC):
for l in loggers:
l.disabled = True
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(self.pretrained_models.keys())
assert tag in self.pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(self.pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def show_rtf(self, info: Dict[str, List[float]]):
"""
Calculate rtf of current task and show results.
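
The isatty() guard added above matters because iterating sys.stdin on an interactive terminal blocks until EOF. A standalone sketch of the per-line rule get_input_source applies (one item per line, either "path" or "id path"; the helper name is illustrative):

```python
import sys
from collections import OrderedDict

def parse_input_lines(lines) -> OrderedDict:
    # Mirrors get_input_source: "path" -> auto-numbered id,
    # "id path" -> explicit id, anything else skipped.
    ret = OrderedDict()
    for i, line in enumerate(lines):
        fields = line.strip().split(' ')
        if len(fields) == 1:
            ret[str(i + 1)] = fields[0]
        elif len(fields) == 2:
            ret[fields[0]] = fields[1]
    return ret

if not sys.stdin.isatty():  # avoid blocking on an interactive (empty) stdin
    print(parse_input_lines(sys.stdin))
```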

@@ -31,21 +31,22 @@ from ..log import logger
from ..utils import download_and_decompress
from ..utils import MODEL_HOME
from ..utils import stats_wrapper
from .pretrained_models import kaldi_bins
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ["STExecutor"]
kaldi_bins = {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
"md5":
"c0682303b3f3393dbf6ed4c4e35a53eb",
}
class STExecutor(BaseExecutor):
def __init__(self):
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
super().__init__(task='st')
self.kaldi_bins = kaldi_bins
self.parser = argparse.ArgumentParser(
@@ -57,7 +58,8 @@ class STExecutor(BaseExecutor):
type=str,
default="fat_st_ted",
choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys()
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
help="Choose model type of st task.")
self.parser.add_argument(
@@ -131,14 +133,16 @@ class STExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None:
tag = model_type + "-" + src_lang + "-" + tgt_lang
res_path = self._get_pretrained_path(tag)
self.cfg_path = os.path.join(res_path,
pretrained_models[tag]["cfg_path"])
self.ckpt_path = os.path.join(res_path,
pretrained_models[tag]["ckpt_path"])
logger.info(res_path)
self.task_resource.set_task_model(tag, version=None)
self.cfg_path = os.path.join(
self.task_resource.res_dir,
self.task_resource.res_dict['cfg_path'])
self.ckpt_path = os.path.join(
self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path'])
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
res_path = self.task_resource.res_dir
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
@@ -163,7 +167,7 @@ class STExecutor(BaseExecutor):
model_conf = self.config
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
model_class = dynamic_import(model_name, self.model_alias)
model_class = self.task_resource.get_model_class(model_name)
self.model = model_class.from_config(model_conf)
self.model.eval()
@@ -301,7 +305,7 @@ class STExecutor(BaseExecutor):
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_source = self.get_input_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False

@@ -1,35 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
"fat_st_ted-en-zh": {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
"md5":
"d62063f35a16d91210a71081bd2dd557",
"cfg_path":
"model.yaml",
"ckpt_path":
"exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
}
}
model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"}
kaldi_bins = {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
"md5":
"c0682303b3f3393dbf6ed4c4e35a53eb",
}
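
Every entry in these tables carries an md5 field, presumably used by the download helpers imported in st/infer.py (download_and_decompress) to validate fetched archives. The check itself is plain hashlib; a sketch against the kaldi_bins archive above:

```python
import hashlib

def md5sum(path: str, chunk_size: int = 1 << 20) -> str:
    # Streaming md5 so large model archives never load fully into memory.
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

# After downloading kaldi_bins.tar.gz:
# assert md5sum('kaldi_bins.tar.gz') == 'c0682303b3f3393dbf6ed4c4e35a53eb'
```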

@@ -1,146 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import List
from prettytable import PrettyTable
from ..utils import cli_register
from ..utils import stats_wrapper
__all__ = ['StatsExecutor']
model_name_format = {
'asr': 'Model-Language-Sample Rate',
'cls': 'Model-Sample Rate',
'st': 'Model-Source language-Target language',
'text': 'Model-Task-Language',
'tts': 'Model-Language',
'vector': 'Model-Sample Rate'
}
@cli_register(
name='paddlespeech.stats',
description='Get speech tasks support models list.')
class StatsExecutor():
def __init__(self):
super().__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech.stats', add_help=True)
self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector']
self.parser.add_argument(
'--task',
type=str,
default='asr',
choices=self.task_choices,
help='Choose speech task.',
required=True)
def show_support_models(self, pretrained_models: dict):
fields = model_name_format[self.task].split("-")
table = PrettyTable(fields)
for key in pretrained_models:
table.add_row(key.split("-"))
print(table)
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
parser_args = self.parser.parse_args(argv)
has_exceptions = False
try:
self(parser_args.task)
except Exception as e:
has_exceptions = True
if has_exceptions:
return False
else:
return True
@stats_wrapper
def __call__(
self,
task: str=None, ):
"""
Python API to call an executor.
"""
self.task = task
if self.task not in self.task_choices:
print("Please input correct speech task, choices = " + str(
self.task_choices))
elif self.task == 'asr':
try:
from ..asr.pretrained_models import pretrained_models
print(
"Here is the list of ASR pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of ASR pretrained models.")
elif self.task == 'cls':
try:
from ..cls.pretrained_models import pretrained_models
print(
"Here is the list of CLS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of CLS pretrained models.")
elif self.task == 'st':
try:
from ..st.pretrained_models import pretrained_models
print(
"Here is the list of ST pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of ST pretrained models.")
elif self.task == 'text':
try:
from ..text.pretrained_models import pretrained_models
print(
"Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of TEXT pretrained models.")
elif self.task == 'tts':
try:
from ..tts.pretrained_models import pretrained_models
print(
"Here is the list of TTS pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print("Failed to get the list of TTS pretrained models.")
elif self.task == 'vector':
try:
from ..vector.pretrained_models import pretrained_models
print(
"Here is the list of Speaker Recognition pretrained models released by PaddleSpeech that can be used by command line and python API"
)
self.show_support_models(pretrained_models)
except BaseException:
print(
"Failed to get the list of Speaker Recognition pretrained models."
)
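
The deleted StatsExecutor above hard-codes one try/except branch per task; the StatsCommand that replaces it (in base_commands.py) collapses them into a single lookup parameterized by task. A sketch of that collapse, with a plain dict standing in for CommonTaskResource(task=...).pretrained_models:

```python
# `registry` stands in for CommonTaskResource(task=task).pretrained_models.
registry = {
    'asr': {"conformer_wenetspeech-zh-16k": {}},
    'cls': {"panns_cnn14-32k": {}},
}

def list_models(task: str) -> None:
    try:
        models = registry[task]
    except KeyError:
        print("Failed to get the list of {} pretrained models.".format(
            task.upper()))
        return
    print("Pretrained {} models:".format(task.upper()))
    for tag in models:
        print(tag)

list_models('asr')
```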

@@ -24,21 +24,13 @@ import paddle
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from .pretrained_models import tokenizer_alias
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TextExecutor']
class TextExecutor(BaseExecutor):
def __init__(self):
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
self.tokenizer_alias = tokenizer_alias
super().__init__(task='text')
self.parser = argparse.ArgumentParser(
prog='paddlespeech.text', add_help=True)
self.parser.add_argument(
@@ -54,7 +46,8 @@ class TextExecutor(BaseExecutor):
type=str,
default='ernie_linear_p7_wudao',
choices=[
tag[:tag.index('-')] for tag in self.pretrained_models.keys()
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
help='Choose model type of text task.')
self.parser.add_argument(
@@ -112,13 +105,16 @@ class TextExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None or vocab_file is None:
tag = '-'.join([model_type, task, lang])
self.res_path = self._get_pretrained_path(tag)
self.task_resource.set_task_model(tag, version=None)
self.cfg_path = os.path.join(
self.res_path, self.pretrained_models[tag]['cfg_path'])
self.task_resource.res_dir,
self.task_resource.res_dict['cfg_path'])
self.ckpt_path = os.path.join(
self.res_path, self.pretrained_models[tag]['ckpt_path'])
self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path'])
self.vocab_file = os.path.join(
self.res_path, self.pretrained_models[tag]['vocab_file'])
self.task_resource.res_dir,
self.task_resource.res_dict['vocab_file'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path)
@@ -133,8 +129,8 @@ class TextExecutor(BaseExecutor):
self._punc_list.append(line.strip())
# model
model_class = dynamic_import(model_name, self.model_alias)
tokenizer_class = dynamic_import(model_name, self.tokenizer_alias)
model_class, tokenizer_class = self.task_resource.get_model_class(
model_name)
self.model = model_class(
cfg_path=self.cfg_path, ckpt_path=self.ckpt_path)
self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
@@ -224,7 +220,7 @@ class TextExecutor(BaseExecutor):
if not parser_args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(parser_args.input)
task_source = self.get_input_source(parser_args.input)
task_results = OrderedDict()
has_exceptions = False
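
For the text task, get_model_class now returns a (model class, tokenizer class) pair, folding the two dynamic_import calls over model_alias and tokenizer_alias into one lookup. The deleted alias tables below show why that works: both dicts share the same keys. A sketch of the pairing (locator resolution as in the import_locator sketch earlier):

```python
model_alias = {"ernie_linear_p7": "paddlespeech.text.models:ErnieLinear"}
tokenizer_alias = {"ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer"}

def get_text_locators(name: str):
    # One key resolves both locators, so a single call can return the pair.
    return model_alias[name], tokenizer_alias[name]

print(get_text_locators("ernie_linear_p7"))
```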

@@ -1,54 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
"ernie_linear_p7_wudao-punc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz',
'md5':
'12283e2ddde1797c5d1e57036b512746',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
"ernie_linear_p3_wudao-punc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz',
'md5':
'448eb2fdf85b6a997e7e652e80c51dd2',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
}
model_alias = {
"ernie_linear_p7": "paddlespeech.text.models:ErnieLinear",
"ernie_linear_p3": "paddlespeech.text.models:ErnieLinear",
}
tokenizer_alias = {
"ernie_linear_p7": "paddlenlp.transformers:ErnieTokenizer",
"ernie_linear_p3": "paddlenlp.transformers:ErnieTokenizer",
}

@@ -29,22 +29,16 @@ from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TTSExecutor']
class TTSExecutor(BaseExecutor):
def __init__(self):
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
super().__init__('tts')
self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True)
self.parser.add_argument(
@@ -183,19 +177,23 @@ class TTSExecutor(BaseExecutor):
return
# am
am_tag = am + '-' + lang
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_config = os.path.join(
am_res_path, self.pretrained_models[am_tag]['config'])
self.am_ckpt = os.path.join(am_res_path,
self.pretrained_models[am_tag]['ckpt'])
self.am_res_path = self.task_resource.res_dir
self.am_config = os.path.join(self.am_res_path,
self.task_resource.res_dict['config'])
self.am_ckpt = os.path.join(self.am_res_path,
self.task_resource.res_dict['ckpt'])
self.am_stat = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speech_stats'])
self.am_res_path, self.task_resource.res_dict['speech_stats'])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict'])
logger.info(am_res_path)
self.am_res_path, self.task_resource.res_dict['phones_dict'])
logger.info(self.am_res_path)
logger.info(self.am_config)
logger.info(self.am_ckpt)
else:
@@ -207,33 +205,37 @@ class TTSExecutor(BaseExecutor):
# for speedyspeech
self.tones_dict = None
if 'tones_dict' in self.pretrained_models[am_tag]:
if 'tones_dict' in self.task_resource.res_dict:
self.tones_dict = os.path.join(
self.am_res_path, self.pretrained_models[am_tag]['tones_dict'])
self.am_res_path, self.task_resource.res_dict['tones_dict'])
if tones_dict:
self.tones_dict = tones_dict
# for multi speaker fastspeech2
self.speaker_dict = None
if 'speaker_dict' in self.pretrained_models[am_tag]:
if 'speaker_dict' in self.task_resource.res_dict:
self.speaker_dict = os.path.join(
self.am_res_path,
self.pretrained_models[am_tag]['speaker_dict'])
self.am_res_path, self.task_resource.res_dict['speaker_dict'])
if speaker_dict:
self.speaker_dict = speaker_dict
# voc
voc_tag = voc + '-' + lang
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
version=None, # default version
)
if voc_ckpt is None or voc_config is None or voc_stat is None:
voc_res_path = self._get_pretrained_path(voc_tag)
self.voc_res_path = voc_res_path
self.voc_res_path = self.task_resource.voc_res_dir
self.voc_config = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['config'])
self.voc_res_path, self.task_resource.voc_res_dict['config'])
self.voc_ckpt = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['ckpt'])
self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
self.voc_stat = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['speech_stats'])
logger.info(voc_res_path)
self.voc_res_path,
self.task_resource.voc_res_dict['speech_stats'])
logger.info(self.voc_res_path)
logger.info(self.voc_config)
logger.info(self.voc_ckpt)
else:
@@ -283,9 +285,9 @@ class TTSExecutor(BaseExecutor):
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_class = dynamic_import(am_name, self.model_alias)
am_inference_class = dynamic_import(am_name + '_inference',
self.model_alias)
am_class = self.task_resource.get_model_class(am_name)
am_inference_class = self.task_resource.get_model_class(am_name +
'_inference')
if am_name == 'fastspeech2':
am = am_class(
@@ -314,9 +316,9 @@ class TTSExecutor(BaseExecutor):
# vocoder
# model: {model_name}_{dataset}
voc_name = voc[:voc.rindex('_')]
voc_class = dynamic_import(voc_name, self.model_alias)
voc_inference_class = dynamic_import(voc_name + '_inference',
self.model_alias)
voc_class = self.task_resource.get_model_class(voc_name)
voc_inference_class = self.task_resource.get_model_class(voc_name +
'_inference')
if voc_name != 'wavernn':
voc = voc_class(**self.voc_config["generator_params"])
voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"])
@@ -444,7 +446,7 @@ class TTSExecutor(BaseExecutor):
if not args.verbose:
self.disable_task_loggers()
task_source = self.get_task_source(args.input)
task_source = self.get_input_source(args.input)
task_results = OrderedDict()
has_exceptions = False
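
TTS is the one task that resolves two resources per run: the acoustic model (model_type=0) and the vocoder (model_type=1), each from its own "{model_name}_{dataset}-{lang}" tag. A sketch of the name arithmetic in the TTS diff above (values are examples from the TTS tables below):

```python
am, voc, lang = 'fastspeech2_csmsc', 'hifigan_csmsc', 'zh'

am_tag = am + '-' + lang    # 'fastspeech2_csmsc-zh' -> set_task_model(..., model_type=0)
voc_tag = voc + '-' + lang  # 'hifigan_csmsc-zh'     -> set_task_model(..., model_type=1)

# Strip the dataset suffix to pick the network class, and append
# '_inference' to pick its inference wrapper, as the TTS executor does.
am_name = am[:am.rindex('_')]      # 'fastspeech2'
voc_name = voc[:voc.rindex('_')]   # 'hifigan'
print(am_name + '_inference', voc_name + '_inference')
```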

@@ -1,300 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip',
'md5':
'6f6fa967b408454b6662c8c00c0027cb',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'feats_stats.npy',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip',
'md5':
'ffed800c93deaf16ca9b3af89bfcd747',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_100000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip',
'md5':
'f4dd4a5f49a4552b77981f544ab3392e',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_96400.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
"fastspeech2_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip',
'md5':
'743e5024ca1e17a88c5c271db9779ba4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_66200.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
# tacotron2
"tacotron2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip',
'md5':
'0df4b6f0bcbe0d73c5ed6df8867ab91a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"tacotron2_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip',
'md5':
'6a5eddd81ae0e81d16959b97481135f3',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_60300.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
# pwgan
"pwgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip',
'md5':
'2e481633325b5bdf0a3823c714d2c117',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
"pwgan_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip',
'md5':
'53610ba9708fd3008ccaf8e99dacbaf0',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
"pwgan_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip',
'md5':
'd7598fa41ad362d62f85ffc0f07e3d84',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
"pwgan_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip',
'md5':
'b3da1defcde3e578be71eb284cb89f2c',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
# style_melgan
"style_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'5de2d5348f396de0c966926b8c462755',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip',
'md5':
'70e9131695decbca06a65fe51ed38a72',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip',
'md5':
'3bb49bc75032ed12f79c00c8cc79a09a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
"hifigan_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip',
'md5':
'7da8f88359bca2457e705d924cf27bd4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# wavernn
"wavernn_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip',
'md5':
'ee37b752f09bcba8f2af3b777ca38e13',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_400000.pdz',
'speech_stats':
'feats_stats.npy',
}
}
model_alias = {
# acoustic model
"speedyspeech":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
"speedyspeech_inference":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
"fastspeech2":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2",
"fastspeech2_inference":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
"tacotron2":
"paddlespeech.t2s.models.tacotron2:Tacotron2",
"tacotron2_inference":
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
# voc
"pwgan":
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
"pwgan_inference":
"paddlespeech.t2s.models.parallel_wavegan:PWGInference",
"mb_melgan":
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
"style_melgan":
"paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
"style_melgan_inference":
"paddlespeech.t2s.models.melgan:StyleMelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
"wavernn":
"paddlespeech.t2s.models.wavernn:WaveRNN",
"wavernn_inference":
"paddlespeech.t2s.models.wavernn:WaveRNNInference",
}

@@ -22,26 +22,20 @@ from typing import Union
import paddle
import soundfile
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from paddlespeech.utils.dynamic_import import dynamic_import
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
class VectorExecutor(BaseExecutor):
def __init__(self):
super().__init__()
self.model_alias = model_alias
self.pretrained_models = pretrained_models
super().__init__('vector')
self.parser = argparse.ArgumentParser(
prog="paddlespeech.vector", add_help=True)
@@ -49,7 +43,10 @@ class VectorExecutor(BaseExecutor):
"--model",
type=str,
default="ecapatdnn_voxceleb12",
choices=["ecapatdnn_voxceleb12"],
choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
help="Choose model type of vector task.")
self.parser.add_argument(
"--task",
@@ -119,7 +116,7 @@ class VectorExecutor(BaseExecutor):
self.disable_task_loggers()
# stage 2: read the input data and store them as a list
task_source = self.get_task_source(parser_args.input)
task_source = self.get_input_source(parser_args.input)
logger.info(f"task source: {task_source}")
# stage 3: process the audio one by one
@@ -296,6 +293,7 @@ class VectorExecutor(BaseExecutor):
# get the mode from pretrained list
sample_rate_str = "16k" if sample_rate == 16000 else "8k"
tag = model_type + "-" + sample_rate_str
self.task_resource.set_task_model(tag, version=None)
logger.info(f"load the pretrained model: {tag}")
# get the model from the pretrained list
# we download the pretrained model and store it in the res_path
@@ -303,10 +301,11 @@ class VectorExecutor(BaseExecutor):
self.res_path = res_path
self.cfg_path = os.path.join(
res_path, self.pretrained_models[tag]['cfg_path'])
self.task_resource.res_dir,
self.task_resource.res_dict['cfg_path'])
self.ckpt_path = os.path.join(
res_path,
self.pretrained_models[tag]['ckpt_path'] + '.pdparams')
self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path'] + '.pdparams')
else:
# get the model from disk
self.cfg_path = os.path.abspath(cfg_path)
@@ -325,8 +324,8 @@ class VectorExecutor(BaseExecutor):
# stage 3: get the model name to instance the model network with dynamic_import
logger.info("start to dynamic import the model class")
model_name = model_type[:model_type.rindex('_')]
model_class = self.task_resource.get_model_class(model_name)
logger.info(f"model name {model_name}")
model_class = dynamic_import(model_name, self.model_alias)
model_conf = self.config.model
backbone = model_class(**model_conf)
model = SpeakerIdetification(
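
Across all tasks, set_task_model(tag, version=None) is the new entry point; the resource tables below (in the new paddlespeech/resource/pretrained_models.py) add a version level under each tag for exactly this call. A sketch of the two-level lookup, where picking the newest version on version=None is an assumption about the default rule:

```python
# Dict shape matches the tables below (tag -> version -> resource dict).
registry = {
    "ecapatdnn_voxceleb12-16k": {
        "1.0": {"cfg_path": "conf/model.yaml", "ckpt_path": "model/model"},
    },
}

def resolve(tag: str, version: str = None) -> dict:
    versions = registry[tag]
    if version is None:
        version = max(versions)  # assumed default: latest version
    return versions[version]

print(resolve("ecapatdnn_voxceleb12-16k"))
```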

@@ -1,36 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
# The tags for pretrained_models should be "{model_name}[-{dataset}][-{sr}][-...]".
# e.g. "ecapatdnn_voxceleb12-16k".
# Command line and python api use "{model_name}[-{dataset}]" as --model, usage:
# "paddlespeech vector --task spk --model ecapatdnn_voxceleb12-16k --sr 16000 --input ./input.wav"
"ecapatdnn_voxceleb12-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_1.tar.gz',
'md5':
'67c7ff8885d5246bd16e0f5ac1cba99f',
'cfg_path':
'conf/model.yaml', # the yaml config path
'ckpt_path':
'model/model', # the format is ${dir}/{model_name},
# so the first 'model' is dir, the second 'model' is the name
# this means we have a model stored as model/model.pdparams
},
}
model_alias = {
"ecapatdnn": "paddlespeech.vector.models.ecapa_tdnn:EcapaTdnn",
}

@@ -1,4 +1,4 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import StatsExecutor
from .resource import CommonTaskResource

@@ -0,0 +1,822 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
'asr_dynamic_pretrained_models',
'asr_static_pretrained_models',
'cls_dynamic_pretrained_models',
'cls_static_pretrained_models',
'st_dynamic_pretrained_models',
'st_kaldi_bins',
'text_dynamic_pretrained_models',
'tts_dynamic_pretrained_models',
'tts_static_pretrained_models',
'tts_onnx_pretrained_models',
'vector_dynamic_pretrained_models',
]
# The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
# e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
# Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
# "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
# ---------------------------------
# -------------- ASR --------------
# ---------------------------------
asr_dynamic_pretrained_models = {
"conformer_wenetspeech-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
'md5':
'76cb19ed857e6623856b7cd7ebbfeda4',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/wenetspeech',
},
},
"conformer_online_wenetspeech-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz',
'md5':
'b8c02632b04da34aca88459835be54a6',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_10',
},
},
"conformer_online_multicn-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz',
'md5':
'7989b3248c898070904cf042fd656003',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/multi_cn',
},
'2.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
'md5':
'0ac93d390552336f2a906aec9e33c5fa',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/multi_cn',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"conformer_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz',
'md5':
'3f073eccfa7bb14e0c6867d65fc0dc3a',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/conformer/checkpoints/avg_30',
},
},
"conformer_online_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz',
'md5':
'b374cfb93537761270b6224fb0bfc26a',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_30',
},
},
"transformer_librispeech-en-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'2c667da24922aad391eacafe37bc1660',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/transformer/checkpoints/avg_10',
},
},
"deepspeech2online_wenetspeech-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz',
'md5':
'e393d4d274af0f6967db24fc146e8074',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_10',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2offline_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2online_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz',
'md5':
'98b87b171b7240b7cae6e07d8d0bc9be',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
},
"deepspeech2offline_librispeech-en-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
'md5':
'f5666c81ad015c8de03aac2bc92e5762',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'lm_url':
'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
'lm_md5':
'099a601759d467cd0a8523ff939819c5'
},
},
}
asr_static_pretrained_models = {
"deepspeech2offline_aishell-zh-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'model':
'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
}
},
}
# ---------------------------------
# -------------- CLS --------------
# ---------------------------------
cls_dynamic_pretrained_models = {
"panns_cnn6-32k": {
'1.0': {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
'md5': '4cf09194a95df024fd12f84712cf0f9c',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn6.pdparams',
'label_file': 'audioset_labels.txt',
},
},
"panns_cnn10-32k": {
'1.0': {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
'md5': 'cb8427b22176cc2116367d14847f5413',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn10.pdparams',
'label_file': 'audioset_labels.txt',
},
},
"panns_cnn14-32k": {
'1.0': {
'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
'md5': 'e3b9b5614a1595001161d0ab95edee97',
'cfg_path': 'panns.yaml',
'ckpt_path': 'cnn14.pdparams',
'label_file': 'audioset_labels.txt',
},
},
}
cls_static_pretrained_models = {
"panns_cnn6-32k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz',
'md5':
'da087c31046d23281d8ec5188c1967da',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
},
"panns_cnn10-32k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz',
'md5':
'5460cc6eafbfaf0f261cc75b90284ae1',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
},
"panns_cnn14-32k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz',
'md5':
'ccc80b194821274da79466862b2ab00f',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
},
}
# ---------------------------------
# -------------- ST ---------------
# ---------------------------------
st_dynamic_pretrained_models = {
"fat_st_ted-en-zh": {
'1.0': {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
"md5":
"d62063f35a16d91210a71081bd2dd557",
"cfg_path":
"model.yaml",
"ckpt_path":
"exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
},
},
}
st_kaldi_bins = {
"url":
"https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
"md5":
"c0682303b3f3393dbf6ed4c4e35a53eb",
}
# ---------------------------------
# -------------- TEXT -------------
# ---------------------------------
text_dynamic_pretrained_models = {
"ernie_linear_p7_wudao-punc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p7_wudao-punc-zh.tar.gz',
'md5':
'12283e2ddde1797c5d1e57036b512746',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
},
"ernie_linear_p3_wudao-punc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/text/ernie_linear_p3_wudao-punc-zh.tar.gz',
'md5':
'448eb2fdf85b6a997e7e652e80c51dd2',
'cfg_path':
'ckpt/model_config.json',
'ckpt_path':
'ckpt/model_state.pdparams',
'vocab_file':
'punc_vocab.txt',
},
},
}
# ---------------------------------
# -------------- TTS --------------
# ---------------------------------
tts_dynamic_pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip',
'md5':
'6f6fa967b408454b6662c8c00c0027cb',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'feats_stats.npy',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
},
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
},
"fastspeech2_ljspeech-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip',
'md5':
'ffed800c93deaf16ca9b3af89bfcd747',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_100000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
},
"fastspeech2_aishell3-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip',
'md5':
'f4dd4a5f49a4552b77981f544ab3392e',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_96400.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
},
"fastspeech2_vctk-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip',
'md5':
'743e5024ca1e17a88c5c271db9779ba4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_66200.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
},
# tacotron2
"tacotron2_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip',
'md5':
'0df4b6f0bcbe0d73c5ed6df8867ab91a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_30600.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
},
"tacotron2_ljspeech-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip',
'md5':
'6a5eddd81ae0e81d16959b97481135f3',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_60300.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
},
# pwgan
"pwgan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip',
'md5':
'2e481633325b5bdf0a3823c714d2c117',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
},
"pwgan_ljspeech-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip',
'md5':
'53610ba9708fd3008ccaf8e99dacbaf0',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
},
"pwgan_aishell3-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip',
'md5':
'd7598fa41ad362d62f85ffc0f07e3d84',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
"pwgan_vctk-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.1.1.zip',
'md5':
'b3da1defcde3e578be71eb284cb89f2c',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
# style_melgan
"style_melgan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'5de2d5348f396de0c966926b8c462755',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
# hifigan
"hifigan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
"hifigan_ljspeech-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_ljspeech_ckpt_0.2.0.zip',
'md5':
'70e9131695decbca06a65fe51ed38a72',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
"hifigan_aishell3-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_aishell3_ckpt_0.2.0.zip',
'md5':
'3bb49bc75032ed12f79c00c8cc79a09a',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
"hifigan_vctk-en": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_vctk_ckpt_0.2.0.zip',
'md5':
'7da8f88359bca2457e705d924cf27bd4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
# wavernn
"wavernn_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/wavernn/wavernn_csmsc_ckpt_0.2.0.zip',
'md5':
'ee37b752f09bcba8f2af3b777ca38e13',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_400000.pdz',
'speech_stats':
'feats_stats.npy',
},
},
"fastspeech2_cnndecoder_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip',
'md5':
'6eb28e22ace73e0ebe7845f86478f89f',
'config':
'cnndecoder.yaml',
'ckpt':
'snapshot_iter_153000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
},
}
tts_static_pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip',
'md5':
'f10cbdedf47dc7a9668d2264494e1823',
'model':
'speedyspeech_csmsc.pdmodel',
'params':
'speedyspeech_csmsc.pdiparams',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
'sample_rate':
24000,
},
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip',
'md5':
'9788cd9745e14c7a5d12d32670b2a5a7',
'model':
'fastspeech2_csmsc.pdmodel',
'params':
'fastspeech2_csmsc.pdiparams',
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
},
# pwgan
"pwgan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip',
'md5':
'e3504aed9c5a290be12d1347836d2742',
'model':
'pwgan_csmsc.pdmodel',
'params':
'pwgan_csmsc.pdiparams',
'sample_rate':
24000,
},
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip',
'md5':
'ac6eee94ba483421d750433f4c3b8d36',
'model':
'mb_melgan_csmsc.pdmodel',
'params':
'mb_melgan_csmsc.pdiparams',
'sample_rate':
24000,
},
},
# hifigan
"hifigan_csmsc-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip',
'md5':
'7edd8c436b3a5546b3a7cb8cff9d5a0c',
'model':
'hifigan_csmsc.pdmodel',
'params':
'hifigan_csmsc.pdiparams',
'sample_rate':
24000,
},
},
}
tts_onnx_pretrained_models = {
# fastspeech2
"fastspeech2_csmsc_onnx-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip',
'md5':
'fd3ad38d83273ad51f0ea4f4abf3ab4e',
'ckpt': ['fastspeech2_csmsc.onnx'],
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
},
"fastspeech2_cnndecoder_csmsc_onnx-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip',
'md5':
'5f70e1a6bcd29d72d54e7931aa86f266',
'ckpt': [
'fastspeech2_csmsc_am_encoder_infer.onnx',
'fastspeech2_csmsc_am_decoder.onnx',
'fastspeech2_csmsc_am_postnet.onnx',
],
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
},
# mb_melgan
"mb_melgan_csmsc_onnx-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip',
'md5':
'5b83ec746e8414bc29032d954ffd07ec',
'ckpt':
'mb_melgan_csmsc.onnx',
'sample_rate':
24000,
},
},
# hifigan
"hifigan_csmsc_onnx-zh": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip',
'md5':
'1a7dc0385875889e46952e50c0994a6b',
'ckpt':
'hifigan_csmsc.onnx',
'sample_rate':
24000,
},
},
}
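# Editor's sketch (helper name is illustrative): in the ONNX registry above,
# 'ckpt' holds either one filename or a list of filenames (the streaming
# cnndecoder model ships encoder/decoder/postnet as separate .onnx files),
# so a loader should normalize it:
def onnx_ckpt_files(res_dict: dict) -> list:
    ckpt = res_dict['ckpt']
    return ckpt if isinstance(ckpt, list) else [ckpt]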
# ---------------------------------
# ------------ Vector -------------
# ---------------------------------
vector_dynamic_pretrained_models = {
"ecapatdnn_voxceleb12-16k": {
'1.0': {
'url':
'https://paddlespeech.bj.bcebos.com/vector/voxceleb/sv0_ecapa_tdnn_voxceleb12_ckpt_0_2_0.tar.gz',
'md5':
'cc33023c54ab346cd318408f43fcaf95',
'cfg_path':
'conf/model.yaml', # the yaml config path
'ckpt_path':
            'model/model',  # the format is {dir}/{model_name}:
            # the first 'model' is the directory, the second is the checkpoint
            # prefix, i.e. the weights are stored as model/model.pdparams
},
},
}
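# Editor's sketch: per the comment above, once the archive is unpacked the
# vector checkpoint resolves to <res_dir>/model/model.pdparams; the helper
# below is illustrative, not project code.
import os

def vector_ckpt_path(res_dir: str, res_dict: dict) -> str:
    # the registry stores the prefix; '.pdparams' is appended at load time
    return os.path.join(res_dir, res_dict['ckpt_path'] + '.pdparams')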

@ -0,0 +1,222 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import OrderedDict
from typing import Dict
from typing import List
from typing import Optional
from ..cli.utils import download_and_decompress
from ..cli.utils import MODEL_HOME
from ..utils.dynamic_import import dynamic_import
from .model_alias import model_alias
task_supported = ['asr', 'cls', 'st', 'text', 'tts', 'vector']
model_format_supported = ['dynamic', 'static', 'onnx']
inference_mode_supported = ['online', 'offline']
class CommonTaskResource:
def __init__(self, task: str, model_format: str='dynamic', **kwargs):
assert task in task_supported, 'Arg "task" must be one of {}.'.format(
task_supported)
assert model_format in model_format_supported, 'Arg "model_format" must be one of {}.'.format(
model_format_supported)
self.task = task
self.model_format = model_format
self.pretrained_models = self._get_pretrained_models()
if 'inference_mode' in kwargs:
assert kwargs[
'inference_mode'] in inference_mode_supported, 'Arg "inference_mode" must be one of {}.'.format(
inference_mode_supported)
self._inference_mode_filter(kwargs['inference_mode'])
        # Placeholders below; filled in once set_task_model() picks a model and version.
self.model_tag = None
self.version = None
self.res_dict = None
self.res_dir = None
if self.task == 'tts':
# For vocoder
self.voc_model_tag = None
self.voc_version = None
self.voc_res_dict = None
self.voc_res_dir = None
def set_task_model(self,
model_tag: str,
model_type: int=0,
version: Optional[str]=None):
"""Set model tag and version of current task.
Args:
model_tag (str): Model tag.
            model_type (int): 0 for the acoustic model; any other value selects the vocoder (TTS task only). Defaults to 0.
version (Optional[str], optional): Version of pretrained model. Defaults to None.
"""
assert model_tag in self.pretrained_models, \
"Can't find \"{}\" in resource. Model name must be one of {}".format(model_tag, list(self.pretrained_models.keys()))
if version is None:
version = self._get_default_version(model_tag)
assert version in self.pretrained_models[model_tag], \
"Can't find version \"{}\" in \"{}\". Model name must be one of {}".format(
version, model_tag, list(self.pretrained_models[model_tag].keys()))
if model_type == 0:
self.model_tag = model_tag
self.version = version
self.res_dict = self.pretrained_models[model_tag][version]
self.res_dir = self._fetch(self.res_dict,
self._get_model_dir(model_type))
else:
assert self.task == 'tts', 'Vocoder will only be used in tts task.'
self.voc_model_tag = model_tag
self.voc_version = version
self.voc_res_dict = self.pretrained_models[model_tag][version]
self.voc_res_dir = self._fetch(self.voc_res_dict,
self._get_model_dir(model_type))
@staticmethod
def get_model_class(model_name) -> List[object]:
"""Dynamic import model class.
Args:
model_name (str): Model name.
Returns:
List[object]: Return a list of model class.
"""
assert model_name in model_alias, 'No model classes found for "{}"'.format(
model_name)
ret = []
for import_path in model_alias[model_name]:
ret.append(dynamic_import(import_path))
if len(ret) == 1:
return ret[0]
else:
return ret
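    # Editor's note: executors in this diff derive the key from a model tag,
    # e.g. model_name = model_type[:model_type.rindex('_')], and instantiate
    # the returned class via model_class.from_config(model_conf).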
def get_versions(self, model_tag: str) -> List[str]:
"""List all available versions.
Args:
model_tag (str): Model tag.
Returns:
List[str]: Version list of model.
"""
return list(self.pretrained_models[model_tag].keys())
def _get_default_version(self, model_tag: str) -> str:
"""Get default version of model.
Args:
model_tag (str): Model tag.
Returns:
str: Default version.
"""
        return self.get_versions(model_tag)[-1]  # the last registered version is treated as the latest
def _get_model_dir(self, model_type: int=0) -> os.PathLike:
"""Get resource directory.
Args:
            model_type (int): 0 for the acoustic model; any other value selects the vocoder (TTS task only).
Returns:
os.PathLike: Directory of model resource.
"""
if model_type == 0:
model_tag = self.model_tag
version = self.version
else:
model_tag = self.voc_model_tag
version = self.voc_version
return os.path.join(MODEL_HOME, model_tag, version)
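    # Editor's note: resources are therefore laid out on disk as
    # MODEL_HOME/<model_tag>/<version>.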
    def _get_pretrained_models(self) -> Dict[str, Dict]:
        """Get all available models for the current task and model format.
        Returns:
            Dict[str, Dict]: A dictionary mapping model tags to versioned resource info.
"""
        try:
            # Compose the registry name, e.g. 'asr_dynamic_pretrained_models',
            # and import it from the pretrained_models module at runtime.
            import_models = '{}_{}_pretrained_models'.format(self.task,
                                                             self.model_format)
            exec('from .pretrained_models import {}'.format(import_models))
            models = OrderedDict(locals()[import_models])
        except ImportError:
            models = OrderedDict({})  # no registry for this task/format pair
        finally:
            return models
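    # Editor's note: an equivalent formulation without exec() would be
    #   module = importlib.import_module('.pretrained_models', __package__)
    #   models = OrderedDict(getattr(module, import_models, {}))
    # which avoids relying on exec() writing into the function's locals().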
def _inference_mode_filter(self, inference_mode: Optional[str]):
"""Filter models dict based on inference_mode.
Args:
inference_mode (Optional[str]): 'online', 'offline' or None.
"""
if inference_mode is None:
return
if self.task == 'asr':
online_flags = [
'online' in model_tag
for model_tag in self.pretrained_models.keys()
]
for online_flag, model_tag in zip(
online_flags, list(self.pretrained_models.keys())):
if inference_mode == 'online' and online_flag:
continue
elif inference_mode == 'offline' and not online_flag:
continue
else:
del self.pretrained_models[model_tag]
elif self.task == 'tts':
            # Hardcoded list of TTS models that currently support online inference.
tts_online_models = [
'fastspeech2_csmsc-zh', 'fastspeech2_cnndecoder_csmsc-zh',
'mb_melgan_csmsc-zh', 'hifigan_csmsc-zh'
]
for model_tag in list(self.pretrained_models.keys()):
if inference_mode == 'online' and model_tag in tts_online_models:
continue
elif inference_mode == 'offline':
continue
else:
del self.pretrained_models[model_tag]
else:
            raise NotImplementedError(
                'Inference mode filtering only supports asr and tts tasks.')
@staticmethod
def _fetch(res_dict: Dict[str, str],
target_dir: os.PathLike) -> os.PathLike:
"""Fetch archive from url.
Args:
res_dict (Dict[str, str]): Info dict of a resource.
target_dir (os.PathLike): Directory to save archives.
Returns:
os.PathLike: Directory of model resource.
"""
return download_and_decompress(res_dict, target_dir)
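# Editor's sketch (illustrative, not part of the commit): typical wiring, as
# done by the executors later in this diff; the tag below appears in the
# online ASR registry shown further down.
resource = CommonTaskResource(
    task='asr', model_format='dynamic', inference_mode='online')
resource.set_task_model('conformer_online_wenetspeech-zh-16k')
cfg_path = os.path.join(resource.res_dir, resource.res_dict['cfg_path'])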

@ -25,6 +25,7 @@ from ..executor import BaseExecutor
from ..util import cli_server_register
from ..util import stats_wrapper
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
@ -152,101 +153,30 @@ class ServerStatsExecutor():
"Please input correct speech task, choices = ['asr', 'tts']")
return False
        elif self.task.lower() == 'asr':
            try:
                from paddlespeech.cli.asr.infer import pretrained_models
                logger.info(
                    "Here is the table of ASR pretrained models supported in the service."
                )
                self.show_support_models(pretrained_models)
                # show ASR static pretrained model
                from paddlespeech.server.engine.asr.paddleinference.asr_engine import pretrained_models
                logger.info(
                    "Here is the table of ASR static pretrained models supported in the service."
                )
                self.show_support_models(pretrained_models)
                return True
            except BaseException:
                logger.error(
                    "Failed to get the table of ASR pretrained models supported in the service."
                )
                return False
        elif self.task.lower() == 'tts':
            try:
                from paddlespeech.cli.tts.infer import pretrained_models
                logger.info(
                    "Here is the table of TTS pretrained models supported in the service."
                )
                self.show_support_models(pretrained_models)
                # show TTS static pretrained model
                from paddlespeech.server.engine.tts.paddleinference.tts_engine import pretrained_models
                logger.info(
                    "Here is the table of TTS static pretrained models supported in the service."
                )
                self.show_support_models(pretrained_models)
                return True
            except BaseException:
                logger.error(
                    "Failed to get the table of TTS pretrained models supported in the service."
                )
                return False
        elif self.task.lower() == 'cls':
            try:
                from paddlespeech.cli.cls.infer import pretrained_models
                logger.info(
                    "Here is the table of CLS pretrained models supported in the service."
                )
                self.show_support_models(pretrained_models)
                # show CLS static pretrained model
                from paddlespeech.server.engine.cls.paddleinference.cls_engine import pretrained_models
                logger.info(
                    "Here is the table of CLS static pretrained models supported in the service."
                )
                self.show_support_models(pretrained_models)
                return True
            except BaseException:
                logger.error(
                    "Failed to get the table of CLS pretrained models supported in the service."
                )
                return False
        elif self.task.lower() == 'text':
            try:
                from paddlespeech.cli.text.infer import pretrained_models
                logger.info(
                    "Here is the table of Text pretrained models supported in the service."
                )
                self.show_support_models(pretrained_models)
                return True
            except BaseException:
                logger.error(
                    "Failed to get the table of Text pretrained models supported in the service."
                )
                return False
        elif self.task.lower() == 'vector':
            try:
                from paddlespeech.cli.vector.infer import pretrained_models
                logger.info(
                    "Here is the table of Vector pretrained models supported in the service."
                )
                self.show_support_models(pretrained_models)
                return True
            except BaseException:
                logger.error(
                    "Failed to get the table of Vector pretrained models supported in the service."
                )
                return False
        else:
            logger.error(
                f"Failed to get the table of {self.task} pretrained models supported in the service."
            )
            return False
        try:
            # Dynamic models
            dynamic_pretrained_models = CommonTaskResource(
                task=self.task, model_format='dynamic').pretrained_models
            if len(dynamic_pretrained_models) > 0:
                logger.info(
                    "Here is the table of {} pretrained models supported in the service.".
                    format(self.task.upper()))
                self.show_support_models(dynamic_pretrained_models)
            # Static models
            static_pretrained_models = CommonTaskResource(
                task=self.task, model_format='static').pretrained_models
            if len(static_pretrained_models) > 0:
                logger.info(
                    "Here is the table of {} static pretrained models supported in the service.".
                    format(self.task.upper()))
                self.show_support_models(static_pretrained_models)
            return True
        except BaseException:
            logger.error(
                "Failed to get the table of {} pretrained models supported in the service.".
                format(self.task.upper()))
            return False
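# Editor's sketch: the unified branch above generalizes to a dump across all
# supported tasks (task list from paddlespeech/resource above):
for task in ['asr', 'cls', 'st', 'text', 'tts', 'vector']:
    models = CommonTaskResource(task=task, model_format='dynamic').pretrained_models
    print(task, list(models.keys()))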

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import sys
from typing import Optional
@ -21,15 +20,14 @@ import paddle
from numpy import float32
from yacs.config import CfgNode
from .pretrained_models import pretrained_models
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.utility import UpdateConfig
@ -53,7 +51,7 @@ class PaddleASRConnectionHanddler:
logger.info(
"create an paddle asr connection handler to process the websocket connection"
)
self.config = asr_engine.config # server config
self.config = asr_engine.config # server config
self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine
@ -251,10 +249,12 @@ class PaddleASRConnectionHanddler:
# for deepspeech2
# init state
self.chunk_state_h_box = np.zeros(
(self.model_config .num_rnn_layers, 1, self.model_config.rnn_layer_size),
(self.model_config.num_rnn_layers, 1,
self.model_config.rnn_layer_size),
dtype=float32)
self.chunk_state_c_box = np.zeros(
(self.model_config.num_rnn_layers, 1, self.model_config.rnn_layer_size),
(self.model_config.num_rnn_layers, 1,
self.model_config.rnn_layer_size),
dtype=float32)
self.decoder.reset_decoder(batch_size=1)
@ -699,7 +699,8 @@ class PaddleASRConnectionHanddler:
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
self.pretrained_models = pretrained_models
self.task_resource = CommonTaskResource(
task='asr', model_format='dynamic', inference_mode='online')
def _init_from_path(self,
model_type: str=None,
@ -723,20 +724,19 @@ class ASRServerExecutor(ASRExecutor):
self.sample_rate = sample_rate
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag)
if cfg_path is None or am_model is None or am_params is None:
logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join(
res_path, self.pretrained_models[tag]['cfg_path'])
self.res_path, self.task_resource.res_dict['cfg_path'])
self.am_model = os.path.join(res_path,
self.pretrained_models[tag]['model'])
self.am_params = os.path.join(res_path,
self.pretrained_models[tag]['params'])
logger.info(res_path)
self.am_model = os.path.join(self.res_path,
self.task_resource.res_dict['model'])
self.am_params = os.path.join(self.res_path,
self.task_resource.res_dict['params'])
logger.info(self.res_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model)
@ -763,8 +763,8 @@ class ASRServerExecutor(ASRExecutor):
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab)
lm_url = self.pretrained_models[tag]['lm_url']
lm_md5 = self.pretrained_models[tag]['lm_md5']
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
@ -810,7 +810,7 @@ class ASRServerExecutor(ASRExecutor):
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
logger.info(f"model name: {model_name}")
model_class = dynamic_import(model_name, self.model_alias)
model_class = self.task_resource.get_model_class(model_name)
model_conf = self.config
model = model_class.from_config(model_conf)
self.model = model
@ -824,7 +824,7 @@ class ASRServerExecutor(ASRExecutor):
raise ValueError(f"Not support: {model_type}")
return True
class ASREngine(BaseEngine):
"""ASR server resource

@ -1,70 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
"deepspeech2online_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz',
'md5':
'98b87b171b7240b7cae6e07d8d0bc9be',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2_online/checkpoints/avg_1',
'model':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"conformer_online_multicn-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
'md5':
'0ac93d390552336f2a906aec9e33c5fa',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/multi_cn',
'model':
'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
'params':
'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
"conformer_online_wenetspeech-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz',
'md5':
'b8c02632b04da34aca88459835be54a6',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/avg_10',
'model':
'exp/chunk_conformer/checkpoints/avg_10.pdparams',
'params':
'exp/chunk_conformer/checkpoints/avg_10.pdparams',
'lm_url':
'',
'lm_md5':
'',
},
}

@ -19,10 +19,10 @@ from typing import Optional
import paddle
from yacs.config import CfgNode
from .pretrained_models import pretrained_models
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
@ -36,7 +36,8 @@ __all__ = ['ASREngine']
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()
self.pretrained_models = pretrained_models
self.task_resource = CommonTaskResource(
task='asr', model_format='static', inference_mode='online')
def _init_from_path(self,
model_type: str='wenetspeech',
@ -53,17 +54,17 @@ class ASRServerExecutor(ASRExecutor):
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag)
if cfg_path is None or am_model is None or am_params is None:
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join(
res_path, self.pretrained_models[tag]['cfg_path'])
self.res_path, self.task_resource.res_dict['cfg_path'])
self.am_model = os.path.join(res_path,
self.pretrained_models[tag]['model'])
self.am_params = os.path.join(res_path,
self.pretrained_models[tag]['params'])
logger.info(res_path)
self.am_model = os.path.join(self.res_path,
self.task_resource.res_dict['model'])
self.am_params = os.path.join(self.res_path,
self.task_resource.res_dict['params'])
logger.info(self.res_path)
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
@ -89,8 +90,8 @@ class ASRServerExecutor(ASRExecutor):
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type, vocab=self.vocab)
lm_url = self.pretrained_models[tag]['lm_url']
lm_md5 = self.pretrained_models[tag]['lm_md5']
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
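# Editor's sketch (assumes the tag below is registered in the central static
# ASR registry): concrete resolution for the static online engine above.
resource = CommonTaskResource(
    task='asr', model_format='static', inference_mode='online')
resource.set_task_model(model_tag='deepspeech2online_aishell-zh-16k')
am_model = os.path.join(resource.res_dir, resource.res_dict['model'])
am_params = os.path.join(resource.res_dir, resource.res_dict['params'])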

@ -1,34 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
"deepspeech2offline_aishell-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
'md5':
'932c3593d62fe5c741b59b31318aa314',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/deepspeech2/checkpoints/avg_1',
'model':
'exp/deepspeech2/checkpoints/avg_1.jit.pdmodel',
'params':
'exp/deepspeech2/checkpoints/avg_1.jit.pdiparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
}

@ -20,9 +20,9 @@ import numpy as np
import paddle
import yaml
from .pretrained_models import pretrained_models
from paddlespeech.cli.cls.infer import CLSExecutor
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.paddle_predictor import init_predictor
from paddlespeech.server.utils.paddle_predictor import run_model
@ -33,11 +33,12 @@ __all__ = ['CLSEngine']
class CLSServerExecutor(CLSExecutor):
def __init__(self):
super().__init__()
self.pretrained_models = pretrained_models
self.task_resource = CommonTaskResource(
task='cls', model_format='static')
def _init_from_path(
self,
model_type: str='panns_cnn14',
model_type: str='panns_cnn14_audioset',
cfg_path: Optional[os.PathLike]=None,
model_path: Optional[os.PathLike]=None,
params_path: Optional[os.PathLike]=None,
@ -49,15 +50,16 @@ class CLSServerExecutor(CLSExecutor):
if cfg_path is None or model_path is None or params_path is None or label_file is None:
tag = model_type + '-' + '32k'
self.res_path = self._get_pretrained_path(tag)
self.task_resource.set_task_model(model_tag=tag)
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join(
self.res_path, self.pretrained_models[tag]['cfg_path'])
self.res_path, self.task_resource.res_dict['cfg_path'])
self.model_path = os.path.join(
self.res_path, self.pretrained_models[tag]['model_path'])
self.res_path, self.task_resource.res_dict['model_path'])
self.params_path = os.path.join(
self.res_path, self.pretrained_models[tag]['params_path'])
self.res_path, self.task_resource.res_dict['params_path'])
self.label_file = os.path.join(
self.res_path, self.pretrained_models[tag]['label_file'])
self.res_path, self.task_resource.res_dict['label_file'])
else:
self.cfg_path = os.path.abspath(cfg_path)
self.model_path = os.path.abspath(model_path)

@ -1,58 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
pretrained_models = {
"panns_cnn6-32k": {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn6_static.tar.gz',
'md5':
'da087c31046d23281d8ec5188c1967da',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
"panns_cnn10-32k": {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn10_static.tar.gz',
'md5':
'5460cc6eafbfaf0f261cc75b90284ae1',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
"panns_cnn14-32k": {
'url':
'https://paddlespeech.bj.bcebos.com/cls/inference_model/panns_cnn14_static.tar.gz',
'md5':
'ccc80b194821274da79466862b2ab00f',
'cfg_path':
'panns.yaml',
'model_path':
'inference.pdmodel',
'params_path':
'inference.pdiparams',
'label_file':
'audioset_labels.txt',
},
}

@ -1,69 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# support online model
pretrained_models = {
# fastspeech2
"fastspeech2_csmsc_onnx-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_onnx_0.2.0.zip',
'md5':
'fd3ad38d83273ad51f0ea4f4abf3ab4e',
'ckpt': ['fastspeech2_csmsc.onnx'],
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
"fastspeech2_cnndecoder_csmsc_onnx-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_streaming_onnx_1.0.0.zip',
'md5':
'5f70e1a6bcd29d72d54e7931aa86f266',
'ckpt': [
'fastspeech2_csmsc_am_encoder_infer.onnx',
'fastspeech2_csmsc_am_decoder.onnx',
'fastspeech2_csmsc_am_postnet.onnx',
],
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
# mb_melgan
"mb_melgan_csmsc_onnx-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_onnx_0.2.0.zip',
'md5':
'5b83ec746e8414bc29032d954ffd07ec',
'ckpt':
'mb_melgan_csmsc.onnx',
'sample_rate':
24000,
},
# hifigan
"hifigan_csmsc_onnx-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_onnx_0.2.0.zip',
'md5':
'1a7dc0385875889e46952e50c0994a6b',
'ckpt':
'hifigan_csmsc.onnx',
'sample_rate':
24000,
},
}

@ -20,9 +20,9 @@ from typing import Optional
import numpy as np
import paddle
from .pretrained_models import pretrained_models
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.onnx_infer import get_sess
@ -43,7 +43,7 @@ class TTSServerExecutor(TTSExecutor):
self.voc_pad = voc_pad
self.voc_upsample = voc_upsample
self.pretrained_models = pretrained_models
self.task_resource = CommonTaskResource(task='tts', model_format='onnx')
def _init_from_path(
self,
@ -72,16 +72,21 @@ class TTSServerExecutor(TTSExecutor):
return
# am
am_tag = am + '-' + lang
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
self.am_res_path = self.task_resource.res_dir
if am == "fastspeech2_csmsc_onnx":
# get model info
if am_ckpt is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_ckpt = os.path.join(
am_res_path, self.pretrained_models[am_tag]['ckpt'][0])
self.am_res_path, self.task_resource.res_dict['ckpt'][0])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict'])
self.am_res_path,
self.task_resource.res_dict['phones_dict'])
else:
self.am_ckpt = os.path.abspath(am_ckpt[0])
@ -94,19 +99,19 @@ class TTSServerExecutor(TTSExecutor):
elif am == "fastspeech2_cnndecoder_csmsc_onnx":
if am_ckpt is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_encoder_infer = os.path.join(
am_res_path, self.pretrained_models[am_tag]['ckpt'][0])
self.am_res_path, self.task_resource.res_dict['ckpt'][0])
self.am_decoder = os.path.join(
am_res_path, self.pretrained_models[am_tag]['ckpt'][1])
self.am_res_path, self.task_resource.res_dict['ckpt'][1])
self.am_postnet = os.path.join(
am_res_path, self.pretrained_models[am_tag]['ckpt'][2])
self.am_res_path, self.task_resource.res_dict['ckpt'][2])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict'])
self.am_res_path,
self.task_resource.res_dict['phones_dict'])
self.am_stat = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speech_stats'])
self.am_res_path,
self.task_resource.res_dict['speech_stats'])
else:
self.am_encoder_infer = os.path.abspath(am_ckpt[0])
@ -131,11 +136,15 @@ class TTSServerExecutor(TTSExecutor):
# voc model info
voc_tag = voc + '-' + lang
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
version=None, # default version
)
if voc_ckpt is None:
voc_res_path = self._get_pretrained_path(voc_tag)
self.voc_res_path = voc_res_path
self.voc_res_path = self.task_resource.voc_res_dir
self.voc_ckpt = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['ckpt'])
self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
else:
self.voc_ckpt = os.path.abspath(voc_ckpt)
self.voc_res_path = os.path.dirname(os.path.abspath(self.voc_ckpt))
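# Editor's sketch: TTS resolves two models per session: the acoustic model
# fills res_dict/res_dir (model_type=0) and the vocoder fills
# voc_res_dict/voc_res_dir (model_type=1), so a single resource object holds
# both. Tags below come from the ONNX registry earlier in this diff.
resource = CommonTaskResource(task='tts', model_format='onnx')
resource.set_task_model(model_tag='fastspeech2_csmsc_onnx-zh', model_type=0)
resource.set_task_model(model_tag='hifigan_csmsc_onnx-zh', model_type=1)
am_ckpt = os.path.join(resource.res_dir, resource.res_dict['ckpt'][0])
voc_ckpt = os.path.join(resource.voc_res_dir, resource.voc_res_dict['ckpt'])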

@ -1,73 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# support online model
pretrained_models = {
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_cnndecoder_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip',
'md5':
'6eb28e22ace73e0ebe7845f86478f89f',
'config':
'cnndecoder.yaml',
'ckpt':
'snapshot_iter_153000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
}

@ -22,9 +22,9 @@ import paddle
import yaml
from yacs.config import CfgNode
from .pretrained_models import pretrained_models
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.util import denorm
@ -32,7 +32,6 @@ from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import
__all__ = ['TTSEngine']
@ -44,7 +43,8 @@ class TTSServerExecutor(TTSExecutor):
self.am_pad = am_pad
self.voc_block = voc_block
self.voc_pad = voc_pad
self.pretrained_models = pretrained_models
self.task_resource = CommonTaskResource(
task='tts', model_format='static', inference_mode='online')
def get_model_info(self,
field: str,
@ -65,7 +65,7 @@ class TTSServerExecutor(TTSExecutor):
[Tensor]: standard deviation
"""
model_class = dynamic_import(model_name, self.model_alias)
model_class = self.task_resource.get_model_class(model_name)
if field == "am":
odim = self.am_config.n_mels
@ -110,20 +110,24 @@ class TTSServerExecutor(TTSExecutor):
return
# am model info
am_tag = am + '-' + lang
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_config = os.path.join(
am_res_path, self.pretrained_models[am_tag]['config'])
self.am_ckpt = os.path.join(am_res_path,
self.pretrained_models[am_tag]['ckpt'])
self.am_res_path = self.task_resource.res_dir
self.am_config = os.path.join(self.am_res_path,
self.task_resource.res_dict['config'])
self.am_ckpt = os.path.join(self.am_res_path,
self.task_resource.res_dict['ckpt'])
self.am_stat = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speech_stats'])
self.am_res_path, self.task_resource.res_dict['speech_stats'])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict'])
self.am_res_path, self.task_resource.res_dict['phones_dict'])
print("self.phones_dict:", self.phones_dict)
logger.info(am_res_path)
logger.info(self.am_res_path)
logger.info(self.am_config)
logger.info(self.am_ckpt)
else:
@ -139,16 +143,21 @@ class TTSServerExecutor(TTSExecutor):
# voc model info
voc_tag = voc + '-' + lang
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
version=None, # default version
)
if voc_ckpt is None or voc_config is None or voc_stat is None:
voc_res_path = self._get_pretrained_path(voc_tag)
self.voc_res_path = voc_res_path
self.voc_res_path = self.task_resource.voc_res_dir
self.voc_config = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['config'])
self.voc_res_path, self.task_resource.voc_res_dict['config'])
self.voc_ckpt = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['ckpt'])
self.voc_res_path, self.task_resource.voc_res_dict['ckpt'])
self.voc_stat = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['speech_stats'])
logger.info(voc_res_path)
self.voc_res_path,
self.task_resource.voc_res_dict['speech_stats'])
logger.info(self.voc_res_path)
logger.info(self.voc_config)
logger.info(self.voc_ckpt)
else:
@ -188,8 +197,8 @@ class TTSServerExecutor(TTSExecutor):
am, am_mu, am_std = self.get_model_info("am", self.am_name,
self.am_ckpt, self.am_stat)
am_normalizer = ZScore(am_mu, am_std)
am_inference_class = dynamic_import(self.am_name + '_inference',
self.model_alias)
am_inference_class = self.task_resource.get_model_class(
self.am_name + '_inference')
self.am_inference = am_inference_class(am_normalizer, am)
self.am_inference.eval()
print("acoustic model done!")
@ -199,8 +208,8 @@ class TTSServerExecutor(TTSExecutor):
voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name,
self.voc_ckpt, self.voc_stat)
voc_normalizer = ZScore(voc_mu, voc_std)
voc_inference_class = dynamic_import(self.voc_name + '_inference',
self.model_alias)
voc_inference_class = self.task_resource.get_model_class(self.voc_name +
'_inference')
self.voc_inference = voc_inference_class(voc_normalizer, voc)
self.voc_inference.eval()
print("voc done!")
@ -505,4 +514,4 @@ class TTSEngine(BaseEngine):
logger.info(f"RTF: {self.executor.final_response_time / duration}")
logger.info(
f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
)
)

@ -1,87 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Static model applied on paddle inference
pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip',
'md5':
'f10cbdedf47dc7a9668d2264494e1823',
'model':
'speedyspeech_csmsc.pdmodel',
'params':
'speedyspeech_csmsc.pdiparams',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
'sample_rate':
24000,
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip',
'md5':
'9788cd9745e14c7a5d12d32670b2a5a7',
'model':
'fastspeech2_csmsc.pdmodel',
'params':
'fastspeech2_csmsc.pdiparams',
'phones_dict':
'phone_id_map.txt',
'sample_rate':
24000,
},
# pwgan
"pwgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_static_0.4.zip',
'md5':
'e3504aed9c5a290be12d1347836d2742',
'model':
'pwgan_csmsc.pdmodel',
'params':
'pwgan_csmsc.pdiparams',
'sample_rate':
24000,
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_static_0.1.1.zip',
'md5':
'ac6eee94ba483421d750433f4c3b8d36',
'model':
'mb_melgan_csmsc.pdmodel',
'params':
'mb_melgan_csmsc.pdiparams',
'sample_rate':
24000,
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_static_0.1.1.zip',
'md5':
'7edd8c436b3a5546b3a7cb8cff9d5a0c',
'model':
'hifigan_csmsc.pdmodel',
'params':
'hifigan_csmsc.pdiparams',
'sample_rate':
24000,
},
}

@ -23,9 +23,9 @@ import paddle
import soundfile as sf
from scipy.io import wavfile
from .pretrained_models import pretrained_models
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.resource import CommonTaskResource
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.errors import ErrorCode
@ -41,7 +41,8 @@ __all__ = ['TTSEngine']
class TTSServerExecutor(TTSExecutor):
def __init__(self):
super().__init__()
self.pretrained_models = pretrained_models
self.task_resource = CommonTaskResource(
task='tts', model_format='static')
def _init_from_path(
self,
@ -67,19 +68,23 @@ class TTSServerExecutor(TTSExecutor):
return
# am
am_tag = am + '-' + lang
self.task_resource.set_task_model(
model_tag=am_tag,
model_type=0, # am
version=None, # default version
)
if am_model is None or am_params is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_model = os.path.join(
am_res_path, self.pretrained_models[am_tag]['model'])
self.am_params = os.path.join(
am_res_path, self.pretrained_models[am_tag]['params'])
self.am_res_path = self.task_resource.res_dir
self.am_model = os.path.join(self.am_res_path,
self.task_resource.res_dict['model'])
self.am_params = os.path.join(self.am_res_path,
self.task_resource.res_dict['params'])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['phones_dict'])
self.am_sample_rate = self.pretrained_models[am_tag]['sample_rate']
self.am_res_path, self.task_resource.res_dict['phones_dict'])
self.am_sample_rate = self.task_resource.res_dict['sample_rate']
logger.info(am_res_path)
logger.info(self.am_res_path)
logger.info(self.am_model)
logger.info(self.am_params)
else:
@ -92,32 +97,36 @@ class TTSServerExecutor(TTSExecutor):
# for speedyspeech
self.tones_dict = None
if 'tones_dict' in self.pretrained_models[am_tag]:
if 'tones_dict' in self.task_resource.res_dict:
self.tones_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['tones_dict'])
self.am_res_path, self.task_resource.res_dict['tones_dict'])
if tones_dict:
self.tones_dict = tones_dict
# for multi speaker fastspeech2
self.speaker_dict = None
if 'speaker_dict' in self.pretrained_models[am_tag]:
if 'speaker_dict' in self.task_resource.res_dict:
self.speaker_dict = os.path.join(
am_res_path, self.pretrained_models[am_tag]['speaker_dict'])
self.am_res_path, self.task_resource.res_dict['speaker_dict'])
if speaker_dict:
self.speaker_dict = speaker_dict
# voc
voc_tag = voc + '-' + lang
self.task_resource.set_task_model(
model_tag=voc_tag,
model_type=1, # vocoder
version=None, # default version
)
if voc_model is None or voc_params is None:
voc_res_path = self._get_pretrained_path(voc_tag)
self.voc_res_path = voc_res_path
self.voc_res_path = self.task_resource.voc_res_dir
self.voc_model = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['model'])
self.voc_res_path, self.task_resource.voc_res_dict['model'])
self.voc_params = os.path.join(
voc_res_path, self.pretrained_models[voc_tag]['params'])
self.voc_sample_rate = self.pretrained_models[voc_tag][
self.voc_res_path, self.task_resource.voc_res_dict['params'])
self.voc_sample_rate = self.task_resource.voc_res_dict[
'sample_rate']
logger.info(voc_res_path)
logger.info(self.voc_res_path)
logger.info(self.voc_model)
logger.info(self.voc_params)
else:
