From 2189b46004376a1cae4e1d67f114d599f8c243d5 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 9 Dec 2021 09:35:08 +0000 Subject: [PATCH] add tts cli --- dataset/aidatatang_200zh/aidatatang_200zh.py | 1 - dataset/aishell/aishell.py | 1 - dataset/librispeech/librispeech.py | 1 - dataset/mini_librispeech/mini_librispeech.py | 1 - dataset/musan/musan.py | 1 - dataset/rir_noise/rir_noise.py | 1 - dataset/thchs30/thchs30.py | 1 - dataset/timit/timit.py | 1 - dataset/voxforge/voxforge.py | 1 - paddlespeech/cli/__init__.py | 1 + paddlespeech/cli/asr/infer.py | 12 +- paddlespeech/cli/cls/infer.py | 4 +- paddlespeech/cli/download.py | 376 ++++++++++ paddlespeech/cli/tts/__init.__py | 0 paddlespeech/cli/tts/__init__.py | 14 + paddlespeech/cli/tts/infer.py | 641 ++++++++++++++++++ paddlespeech/cli/utils.py | 6 +- paddlespeech/cls/exps/panns/deploy/predict.py | 3 +- paddlespeech/cls/exps/panns/export_model.py | 2 +- paddlespeech/cls/exps/panns/predict.py | 2 +- paddlespeech/cls/exps/panns/train.py | 2 +- paddlespeech/cls/models/panns/panns.py | 1 - .../s2t/exps/deepspeech2/bin/__init__.py | 13 + .../exps/deepspeech2/bin/deploy/__init__.py | 13 + .../s2t/exps/lm/transformer/bin/__init__.py | 13 + paddlespeech/s2t/exps/u2/bin/__init__.py | 13 + .../s2t/exps/u2_kaldi/bin/__init__.py | 13 + paddlespeech/s2t/exps/u2_st/bin/__init__.py | 13 + paddlespeech/s2t/exps/u2_st/model.py | 2 - .../multi_spk_synthesize_e2e_en.py | 15 +- .../t2s/exps/fastspeech2/synthesize_e2e_en.py | 15 +- paddlespeech/t2s/frontend/phonectic.py | 49 +- utils/manifest_key_value.py | 1 - 33 files changed, 1154 insertions(+), 79 deletions(-) create mode 100644 paddlespeech/cli/download.py delete mode 100644 paddlespeech/cli/tts/__init.__py create mode 100644 paddlespeech/cli/tts/__init__.py create mode 100644 paddlespeech/cli/tts/infer.py diff --git a/dataset/aidatatang_200zh/aidatatang_200zh.py b/dataset/aidatatang_200zh/aidatatang_200zh.py index 85f478c20..b8758c9a7 100644 --- a/dataset/aidatatang_200zh/aidatatang_200zh.py +++ b/dataset/aidatatang_200zh/aidatatang_200zh.py @@ -25,7 +25,6 @@ import os from pathlib import Path import soundfile - from utils.utility import download from utils.utility import unpack diff --git a/dataset/aishell/aishell.py b/dataset/aishell/aishell.py index 7431fc083..32dc119d2 100644 --- a/dataset/aishell/aishell.py +++ b/dataset/aishell/aishell.py @@ -25,7 +25,6 @@ import os from pathlib import Path import soundfile - from utils.utility import download from utils.utility import unpack diff --git a/dataset/librispeech/librispeech.py b/dataset/librispeech/librispeech.py index 69f0db599..0c779696d 100644 --- a/dataset/librispeech/librispeech.py +++ b/dataset/librispeech/librispeech.py @@ -27,7 +27,6 @@ import os from multiprocessing.pool import Pool import soundfile - from utils.utility import download from utils.utility import unpack diff --git a/dataset/mini_librispeech/mini_librispeech.py b/dataset/mini_librispeech/mini_librispeech.py index 730c73a8b..d96b5d64d 100644 --- a/dataset/mini_librispeech/mini_librispeech.py +++ b/dataset/mini_librispeech/mini_librispeech.py @@ -26,7 +26,6 @@ import os from multiprocessing.pool import Pool import soundfile - from utils.utility import download from utils.utility import unpack diff --git a/dataset/musan/musan.py b/dataset/musan/musan.py index 2ac701bed..dc237c30a 100644 --- a/dataset/musan/musan.py +++ b/dataset/musan/musan.py @@ -28,7 +28,6 @@ import json import os import soundfile - from utils.utility import download from utils.utility import unpack diff --git a/dataset/rir_noise/rir_noise.py b/dataset/rir_noise/rir_noise.py index e7b122890..0e055f17b 100644 --- a/dataset/rir_noise/rir_noise.py +++ b/dataset/rir_noise/rir_noise.py @@ -28,7 +28,6 @@ import json import os import soundfile - from utils.utility import download from utils.utility import unzip diff --git a/dataset/thchs30/thchs30.py b/dataset/thchs30/thchs30.py index cdfc0a75c..879ed58db 100644 --- a/dataset/thchs30/thchs30.py +++ b/dataset/thchs30/thchs30.py @@ -26,7 +26,6 @@ from multiprocessing.pool import Pool from pathlib import Path import soundfile - from utils.utility import download from utils.utility import unpack diff --git a/dataset/timit/timit.py b/dataset/timit/timit.py index c4a9f0663..d03c48a1e 100644 --- a/dataset/timit/timit.py +++ b/dataset/timit/timit.py @@ -27,7 +27,6 @@ import string from pathlib import Path import soundfile - from utils.utility import unzip URL_ROOT = "" diff --git a/dataset/voxforge/voxforge.py b/dataset/voxforge/voxforge.py index 373791bff..c388f4491 100644 --- a/dataset/voxforge/voxforge.py +++ b/dataset/voxforge/voxforge.py @@ -27,7 +27,6 @@ import shutil import subprocess import soundfile - from utils.utility import download_multi from utils.utility import getfile_insensitive from utils.utility import unpack diff --git a/paddlespeech/cli/__init__.py b/paddlespeech/cli/__init__.py index 246d0f381..99a53c37e 100644 --- a/paddlespeech/cli/__init__.py +++ b/paddlespeech/cli/__init__.py @@ -16,3 +16,4 @@ from .base_commands import BaseCommand from .base_commands import HelpCommand from .cls import CLSExecutor from .st import STExecutor +from .tts import TTSExecutor diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 1e59f015a..2db239c0d 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -119,7 +119,7 @@ class ASRExecutor(BaseExecutor): def _get_pretrained_path(self, tag: str) -> os.PathLike: """ - Download and returns pretrained resources path of current task. + Download and returns pretrained resources path of current task. """ assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format( tag) @@ -140,7 +140,7 @@ class ASRExecutor(BaseExecutor): cfg_path: Optional[os.PathLike]=None, ckpt_path: Optional[os.PathLike]=None): """ - Init model and other resources from a specific path. + Init model and other resources from a specific path. """ if hasattr(self, 'model'): logger.info('Model had been initialized.') @@ -216,8 +216,8 @@ class ASRExecutor(BaseExecutor): def preprocess(self, model_type: str, input: Union[str, os.PathLike]): """ - Input preprocess and return paddle.Tensor stored in self.input. - Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). + Input preprocess and return paddle.Tensor stored in self.input. + Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet). """ audio_file = input @@ -291,7 +291,7 @@ class ASRExecutor(BaseExecutor): @paddle.no_grad() def infer(self, model_type: str): """ - Model inference and result stored in self.output. + Model inference and result stored in self.output. """ text_feature = TextFeaturizer( unit_type=self.config.collator.unit_type, @@ -438,7 +438,7 @@ class ASRExecutor(BaseExecutor): def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file, device): """ - Python API to call an executor. + Python API to call an executor. """ audio_file = os.path.abspath(audio_file) self._check(audio_file, sample_rate) diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index b73d16679..795d59f68 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -20,14 +20,14 @@ from typing import Union import numpy as np import paddle import yaml +from paddleaudio import load +from paddleaudio.features import LogMelSpectrogram from ..executor import BaseExecutor from ..utils import cli_register from ..utils import download_and_decompress from ..utils import logger from ..utils import MODEL_HOME -from paddleaudio import load -from paddleaudio.features import LogMelSpectrogram from paddlespeech.s2t.utils.dynamic_import import dynamic_import __all__ = ['CLSExecutor'] diff --git a/paddlespeech/cli/download.py b/paddlespeech/cli/download.py new file mode 100644 index 000000000..8de9f0459 --- /dev/null +++ b/paddlespeech/cli/download.py @@ -0,0 +1,376 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import hashlib +import os +import os.path as osp +import shutil +import subprocess +import sys +import tarfile +import time +import zipfile + +import requests + +try: + from tqdm import tqdm +except: + + class tqdm(object): + def __init__(self, total=None): + self.total = total + self.n = 0 + + def update(self, n): + self.n += n + if self.total is None: + sys.stderr.write("\r{0:.1f} bytes".format(self.n)) + else: + sys.stderr.write( + "\r{0:.1f}%".format(100 * self.n / float(self.total))) + sys.stderr.flush() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stderr.write('\n') + + +import logging +logger = logging.getLogger(__name__) + +__all__ = ['get_weights_path_from_url'] + +WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights") + +DOWNLOAD_RETRY_LIMIT = 3 + + +def is_url(path): + """ + Whether path is URL. + Args: + path (string): URL string or not. + """ + return path.startswith('http://') or path.startswith('https://') + + +def get_weights_path_from_url(url, md5sum=None): + """Get weights path from WEIGHT_HOME, if not exists, + download it from url. + Args: + url (str): download url + md5sum (str): md5 sum of download package + + Returns: + str: a local path to save downloaded weights. + Examples: + .. code-block:: python + from paddle.utils.download import get_weights_path_from_url + resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams' + local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url) + """ + path = get_path_from_url(url, WEIGHTS_HOME, md5sum) + return path + + +def _map_path(url, root_dir): + # parse path after download under root_dir + fname = osp.split(url)[-1] + fpath = fname + return osp.join(root_dir, fpath) + + +def _get_unique_endpoints(trainer_endpoints): + # Sorting is to avoid different environmental variables for each card + trainer_endpoints.sort() + ips = set() + unique_endpoints = set() + for endpoint in trainer_endpoints: + ip = endpoint.split(":")[0] + if ip in ips: + continue + ips.add(ip) + unique_endpoints.add(endpoint) + logger.info("unique_endpoints {}".format(unique_endpoints)) + return unique_endpoints + + +def get_path_from_url(url, + root_dir, + md5sum=None, + check_exist=True, + decompress=True, + method='get'): + """ Download from given url to root_dir. + if file or directory specified by url is exists under + root_dir, return the path directly, otherwise download + from url and decompress it, return the path. + Args: + url (str): download url + root_dir (str): root dir for downloading, it should be + WEIGHTS_HOME or DATASET_HOME + md5sum (str): md5 sum of download package + decompress (bool): decompress zip or tar file. Default is `True` + method (str): which download method to use. Support `wget` and `get`. Default is `get`. + Returns: + str: a local path to save downloaded models & weights & datasets. + """ + + from paddle.fluid.dygraph.parallel import ParallelEnv + + assert is_url(url), "downloading from {} not a url".format(url) + # parse path after download to decompress under root_dir + fullpath = _map_path(url, root_dir) + # Mainly used to solve the problem of downloading data from different + # machines in the case of multiple machines. Different ips will download + # data, and the same ip will only download data once. + unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:]) + if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum): + logger.info("Found {}".format(fullpath)) + else: + if ParallelEnv().current_endpoint in unique_endpoints: + fullpath = _download(url, root_dir, md5sum, method=method) + else: + while not os.path.exists(fullpath): + time.sleep(1) + + if ParallelEnv().current_endpoint in unique_endpoints: + if decompress and (tarfile.is_tarfile(fullpath) or + zipfile.is_zipfile(fullpath)): + fullpath = _decompress(fullpath) + + return fullpath + + +def _get_download(url, fullname): + # using requests.get method + fname = osp.basename(fullname) + try: + req = requests.get(url, stream=True) + except Exception as e: # requests.exceptions.ConnectionError + logger.info("Downloading {} from {} failed with exception {}".format( + fname, url, str(e))) + return False + + if req.status_code != 200: + raise RuntimeError("Downloading from {} failed with code " + "{}!".format(url, req.status_code)) + + # For protecting download interupted, download to + # tmp_fullname firstly, move tmp_fullname to fullname + # after download finished + tmp_fullname = fullname + "_tmp" + total_size = req.headers.get('content-length') + with open(tmp_fullname, 'wb') as f: + if total_size: + with tqdm(total=(int(total_size) + 1023) // 1024) as pbar: + for chunk in req.iter_content(chunk_size=1024): + f.write(chunk) + pbar.update(1) + else: + for chunk in req.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + shutil.move(tmp_fullname, fullname) + + return fullname + + +def _wget_download(url, fullname): + # using wget to download url + tmp_fullname = fullname + "_tmp" + # –user-agent + command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT, + url) + subprc = subprocess.Popen( + command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + _ = subprc.communicate() + + if subprc.returncode != 0: + raise RuntimeError( + '{} failed. Please make sure `wget` is installed or {} exists'. + format(command, url)) + + shutil.move(tmp_fullname, fullname) + + return fullname + + +_download_methods = { + 'get': _get_download, + 'wget': _wget_download, +} + + +def _download(url, path, md5sum=None, method='get'): + """ + Download from url, save to path. + url (str): download url + path (str): download to given path + md5sum (str): md5 sum of download package + method (str): which download method to use. Support `wget` and `get`. Default is `get`. + """ + assert method in _download_methods, 'make sure `{}` implemented'.format( + method) + + if not osp.exists(path): + os.makedirs(path) + + fname = osp.split(url)[-1] + fullname = osp.join(path, fname) + retry_cnt = 0 + + logger.info("Downloading {} from {}".format(fname, url)) + while not (osp.exists(fullname) and _md5check(fullname, md5sum)): + if retry_cnt < DOWNLOAD_RETRY_LIMIT: + retry_cnt += 1 + else: + raise RuntimeError("Download from {} failed. " + "Retry limit reached".format(url)) + + if not _download_methods[method](url, fullname): + time.sleep(1) + continue + + return fullname + + +def _md5check(fullname, md5sum=None): + if md5sum is None: + return True + + logger.info("File {} md5 checking...".format(fullname)) + md5 = hashlib.md5() + with open(fullname, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b""): + md5.update(chunk) + calc_md5sum = md5.hexdigest() + + if calc_md5sum != md5sum: + logger.info("File {} md5 check failed, {}(calc) != " + "{}(base)".format(fullname, calc_md5sum, md5sum)) + return False + return True + + +def _decompress(fname): + """ + Decompress for zip and tar file + """ + logger.info("Decompressing {}...".format(fname)) + + # For protecting decompressing interupted, + # decompress to fpath_tmp directory firstly, if decompress + # successed, move decompress files to fpath and delete + # fpath_tmp and remove download compress file. + + if tarfile.is_tarfile(fname): + uncompressed_path = _uncompress_file_tar(fname) + elif zipfile.is_zipfile(fname): + uncompressed_path = _uncompress_file_zip(fname) + else: + raise TypeError("Unsupport compress file type {}".format(fname)) + + return uncompressed_path + + +def _uncompress_file_zip(filepath): + files = zipfile.ZipFile(filepath, 'r') + file_list = files.namelist() + + file_dir = os.path.dirname(filepath) + + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + + for item in file_list: + files.extract(item, file_dir) + + elif _is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0] + uncompressed_path = os.path.join(file_dir, rootpath) + + for item in file_list: + files.extract(item, file_dir) + + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + if not os.path.exists(uncompressed_path): + os.makedirs(uncompressed_path) + for item in file_list: + files.extract(item, os.path.join(file_dir, rootpath)) + + files.close() + + return uncompressed_path + + +def _uncompress_file_tar(filepath, mode="r:*"): + files = tarfile.open(filepath, mode) + file_list = files.getnames() + + file_dir = os.path.dirname(filepath) + + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + for item in file_list: + files.extract(item, file_dir) + elif _is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + for item in file_list: + files.extract(item, file_dir) + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + if not os.path.exists(uncompressed_path): + os.makedirs(uncompressed_path) + + for item in file_list: + files.extract(item, os.path.join(file_dir, rootpath)) + + files.close() + + return uncompressed_path + + +def _is_a_single_file(file_list): + if len(file_list) == 1 and file_list[0].find(os.sep) < -1: + return True + return False + + +def _is_a_single_dir(file_list): + new_file_list = [] + for file_path in file_list: + if '/' in file_path: + file_path = file_path.replace('/', os.sep) + elif '\\' in file_path: + file_path = file_path.replace('\\', os.sep) + new_file_list.append(file_path) + + file_name = new_file_list[0].split(os.sep)[0] + for i in range(1, len(new_file_list)): + if file_name != new_file_list[i].split(os.sep)[0]: + return False + return True diff --git a/paddlespeech/cli/tts/__init.__py b/paddlespeech/cli/tts/__init.__py deleted file mode 100644 index e69de29bb..000000000 diff --git a/paddlespeech/cli/tts/__init__.py b/paddlespeech/cli/tts/__init__.py new file mode 100644 index 000000000..4cc3c42fc --- /dev/null +++ b/paddlespeech/cli/tts/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .infer import TTSExecutor diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py new file mode 100644 index 000000000..771b7d6dc --- /dev/null +++ b/paddlespeech/cli/tts/infer.py @@ -0,0 +1,641 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +from typing import Any +from typing import List +from typing import Optional +from typing import Union + +import numpy as np +import paddle +import soundfile as sf +import yaml +from yacs.config import CfgNode + +from ..executor import BaseExecutor +from ..utils import cli_register +from ..utils import download_and_decompress +from ..utils import logger +from ..utils import MODEL_HOME +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.zh_frontend import Frontend +from paddlespeech.t2s.modules.normalizer import ZScore + +__all__ = ['TTSExecutor'] + +pretrained_models = { + # speedyspeech + "speedyspeech_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip', + 'md5': + '9edce23b1a87f31b814d9477bf52afbc', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_11400.pdz', + 'speech_stats': + 'feats_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'tones_dict': + 'tone_id_map.txt', + }, + + # fastspeech2 + "fastspeech2_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', + 'md5': + '637d28a5e53aa60275612ba4393d5f22', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_76000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "fastspeech2_ljspeech-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip', + 'md5': + 'ffed800c93deaf16ca9b3af89bfcd747', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_100000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "fastspeech2_aishell3-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip', + 'md5': + 'f4dd4a5f49a4552b77981f544ab3392e', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_96400.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'speaker_dict': + 'speaker_id_map.txt', + }, + "fastspeech2_vctk-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip', + 'md5': + '743e5024ca1e17a88c5c271db9779ba4', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_66200.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + 'speaker_dict': + 'speaker_id_map.txt', + }, + # pwgan + "pwgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip', + 'md5': + '2e481633325b5bdf0a3823c714d2c117', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_400000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + "pwgan_ljspeech-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip', + 'md5': + '53610ba9708fd3008ccaf8e99dacbaf0', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_400000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + "pwgan_aishell3-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip', + 'md5': + 'd7598fa41ad362d62f85ffc0f07e3d84', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + "pwgan_vctk-en": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.5.zip', + 'md5': + '322ca688aec9b127cec2788b65aa3d52', + 'config': + 'pwg_default.yaml', + 'ckpt': + 'pwg_snapshot_iter_1000000.pdz', + 'speech_stats': + 'pwg_stats.npy', + }, + # mb_melgan + "mb_melgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip', + 'md5': + 'b69322ab4ea766d955bd3d9af7dc5f2d', + 'config': + 'finetune.yaml', + 'ckpt': + 'snapshot_iter_2000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, +} + +model_alias = { + # acoustic model + "speedyspeech": + "paddlespeech.t2s.models.speedyspeech:SpeedySpeech", + "speedyspeech_inference": + "paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference", + "fastspeech2": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2", + "fastspeech2_inference": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", + # voc + "pwgan": + "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", + "pwgan_inference": + "paddlespeech.t2s.models.parallel_wavegan:PWGInference", + "mb_melgan": + "paddlespeech.t2s.models.melgan:MelGANGenerator", + "mb_melgan_inference": + "paddlespeech.t2s.models.melgan:MelGANInference", +} + + +@cli_register( + name='paddlespeech.tts', description='Text to Speech infer command.') +class TTSExecutor(BaseExecutor): + def __init__(self): + super().__init__() + + self.parser = argparse.ArgumentParser( + prog='paddlespeech.tts', add_help=True) + self.parser.add_argument( + '--input', type=str, required=True, help='Input text to generate.') + # acoustic model + self.parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=[ + 'speedyspeech_csmsc', 'fastspeech2_csmsc', + 'fastspeech2_ljspeech', 'fastspeech2_aishell3', + 'fastspeech2_vctk' + ], + help='Choose acoustic model type of tts task.') + self.parser.add_argument( + '--am_config', + type=str, + default=None, + help='Config of acoustic model. Use deault config when it is None.') + self.parser.add_argument( + '--am_ckpt', + type=str, + default=None, + help='Checkpoint file of acoustic model.') + self.parser.add_argument( + "--am_stat", + type=str, + help="mean and standard deviation used to normalize spectrogram when training acoustic model." + ) + self.parser.add_argument( + "--phones_dict", + type=str, + default=None, + help="phone vocabulary file.") + self.parser.add_argument( + "--tones_dict", + type=str, + default=None, + help="tone vocabulary file.") + self.parser.add_argument( + "--speaker_dict", + type=str, + default=None, + help="speaker id map file.") + self.parser.add_argument( + '--spk_id', + type=int, + default=0, + help='spk id for multi speaker acoustic model') + # vocoder + self.parser.add_argument( + '--voc', + type=str, + default='pwgan_csmsc', + choices=[ + 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', + 'mb_melgan_csmsc' + ], + help='Choose vocoder type of tts task.') + + self.parser.add_argument( + '--voc_config', + type=str, + default=None, + help='Config of voc. Use deault config when it is None.') + self.parser.add_argument( + '--voc_ckpt', + type=str, + default=None, + help='Checkpoint file of voc.') + self.parser.add_argument( + "--voc_stat", + type=str, + help="mean and standard deviation used to normalize spectrogram when training voc." + ) + # other + self.parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') + self.parser.add_argument( + '--device', + type=str, + default=paddle.get_device(), + help='Choose device to execute model inference.') + + self.parser.add_argument( + '--output', type=str, default='output.wav', help='output file name') + + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. + """ + assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format( + tag) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + return decompressed_path + + def _init_from_path( + self, + am: str='fastspeech2_csmsc', + am_config: Optional[os.PathLike]=None, + am_ckpt: Optional[os.PathLike]=None, + am_stat: Optional[os.PathLike]=None, + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None, + speaker_dict: Optional[os.PathLike]=None, + voc: str='pwgan_csmsc', + voc_config: Optional[os.PathLike]=None, + voc_ckpt: Optional[os.PathLike]=None, + voc_stat: Optional[os.PathLike]=None, + lang: str='zh', ): + """ + Init model and other resources from a specific path. + """ + if hasattr(self, 'am') and hasattr(self, 'voc'): + logger.info('Models had been initialized.') + return + # am + am_tag = am + '-' + lang + if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: + am_res_path = self._get_pretrained_path(am_tag) + self.am_res_path = am_res_path + self.am_config = os.path.join(am_res_path, + pretrained_models[am_tag]['config']) + self.am_ckpt = os.path.join(am_res_path, + pretrained_models[am_tag]['ckpt']) + self.am_stat = os.path.join( + am_res_path, pretrained_models[am_tag]['speech_stats']) + # must have phones_dict in acoustic + self.phones_dict = os.path.join( + am_res_path, pretrained_models[am_tag]['phones_dict']) + print("self.phones_dict:", self.phones_dict) + logger.info(am_res_path) + logger.info(self.am_config) + logger.info(self.am_ckpt) + else: + self.am_config = os.path.abspath(am_config) + self.am_ckpt = os.path.abspath(am_ckpt) + self.am_stat = os.path.abspath(am_stat) + self.phones_dict = os.path.abspath(phones_dict) + self.am_res_path = os.path.dirname(os.path.abspath(self.am_config)) + print("self.phones_dict:", self.phones_dict) + + # for speedyspeech + self.tones_dict = None + if 'tones_dict' in pretrained_models[am_tag]: + self.tones_dict = os.path.join( + am_res_path, pretrained_models[am_tag]['tones_dict']) + if tones_dict: + self.tones_dict = tones_dict + + # for multi speaker fastspeech2 + self.speaker_dict = None + if 'speaker_dict' in pretrained_models[am_tag]: + self.speaker_dict = os.path.join( + am_res_path, pretrained_models[am_tag]['speaker_dict']) + if speaker_dict: + self.speaker_dict = speaker_dict + + # voc + voc_tag = voc + '-' + lang + if voc_ckpt is None or voc_config is None or voc_stat is None: + voc_res_path = self._get_pretrained_path(voc_tag) + self.voc_res_path = voc_res_path + self.voc_config = os.path.join(voc_res_path, + pretrained_models[voc_tag]['config']) + self.voc_ckpt = os.path.join(voc_res_path, + pretrained_models[voc_tag]['ckpt']) + self.voc_stat = os.path.join( + voc_res_path, pretrained_models[voc_tag]['speech_stats']) + logger.info(voc_res_path) + logger.info(self.voc_config) + logger.info(self.voc_ckpt) + else: + self.voc_config = os.path.abspath(voc_config) + self.voc_ckpt = os.path.abspath(voc_ckpt) + self.voc_stat = os.path.abspath(voc_stat) + self.voc_res_path = os.path.dirname( + os.path.abspath(self.voc_config)) + + # Init body. + with open(self.am_config) as f: + self.am_config = CfgNode(yaml.safe_load(f)) + with open(self.voc_config) as f: + self.voc_config = CfgNode(yaml.safe_load(f)) + + # Enter the path of model root + + with open(self.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + tone_size = None + if self.tones_dict: + with open(self.tones_dict, "r") as f: + tone_id = [line.strip().split() for line in f.readlines()] + tone_size = len(tone_id) + print("tone_size:", tone_size) + + spk_num = None + if self.speaker_dict: + with open(self.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + print("spk_num:", spk_num) + + # frontend + if lang == 'zh': + self.frontend = Frontend( + phone_vocab_path=self.phones_dict, + tone_vocab_path=self.tones_dict) + + elif lang == 'en': + self.frontend = English(phone_vocab_path=self.phones_dict) + print("frontend done!") + + # acoustic model + odim = self.am_config.n_mels + # model: {model_name}_{dataset} + am_name = am[:am.rindex('_')] + + am_class = dynamic_import(am_name, model_alias) + am_inference_class = dynamic_import(am_name + '_inference', model_alias) + + if am_name == 'fastspeech2': + am = am_class( + idim=vocab_size, + odim=odim, + spk_num=spk_num, + **self.am_config["model"]) + elif am_name == 'speedyspeech': + am = am_class( + vocab_size=vocab_size, + tone_size=tone_size, + **self.am_config["model"]) + + am.set_state_dict(paddle.load(self.am_ckpt)["main_params"]) + am.eval() + am_mu, am_std = np.load(self.am_stat) + am_mu = paddle.to_tensor(am_mu) + am_std = paddle.to_tensor(am_std) + am_normalizer = ZScore(am_mu, am_std) + self.am_inference = am_inference_class(am_normalizer, am) + print("acoustic model done!") + + # vocoder + # model: {model_name}_{dataset} + voc_name = '_'.join(voc.split('_')[:-1]) + voc_class = dynamic_import(voc_name, model_alias) + voc_inference_class = dynamic_import(voc_name + '_inference', + model_alias) + voc = voc_class(**self.voc_config["generator_params"]) + voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"]) + voc.remove_weight_norm() + voc.eval() + voc_mu, voc_std = np.load(self.voc_stat) + voc_mu = paddle.to_tensor(voc_mu) + voc_std = paddle.to_tensor(voc_std) + voc_normalizer = ZScore(voc_mu, voc_std) + self.voc_inference = voc_inference_class(voc_normalizer, voc) + print("voc done!") + + def preprocess(self, input: Any, *args, **kwargs): + """ + Input preprocess and return paddle.Tensor stored in self._inputs. + Input content can be a text(tts), a file(asr, cls), a stream(not supported yet) or anything needed. + + Args: + input (Any): Input text/file/stream or other content. + """ + pass + + @paddle.no_grad() + def infer(self, + text: str, + lang: str='zh', + am: str='fastspeech2_csmsc', + spk_id: int=0): + """ + Model inference and result stored in self.output. + """ + model_name = am[:am.rindex('_')] + dataset = am[am.rindex('_') + 1:] + get_tone_ids = False + if 'speedyspeech' in model_name: + get_tone_ids = True + if lang == 'zh': + input_ids = self.frontend.get_input_ids( + text, merge_sentences=True, get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + phone_ids = phone_ids[0] + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + tone_ids = tone_ids[0] + elif lang == 'en': + input_ids = self.frontend.get_input_ids(text) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + + # am + if 'speedyspeech' in model_name: + mel = self.am_inference(phone_ids, tone_ids) + # fastspeech2 + else: + # multi speaker + if dataset in {"aishell3", "vctk"}: + mel = self.am_inference( + phone_ids, spk_id=paddle.to_tensor(spk_id)) + + else: + mel = self.am_inference(phone_ids) + + # voc + wav = self.voc_inference(mel) + self._outputs['wav'] = wav + + def postprocess(self, output: str='output.wav'): + """ + Output postprocess and return results. + This method get model output from self._outputs and convert it into human-readable results. + + Returns: + Union[str, os.PathLike]: Human-readable results such as texts and audio files. + """ + sf.write( + output, self._outputs['wav'].numpy(), samplerate=self.am_config.fs) + return output + + def execute(self, argv: List[str]) -> bool: + """ + Command line entry. + """ + + args = self.parser.parse_args(argv) + + text = args.input + am = args.am + am_config = args.am_config + am_ckpt = args.am_ckpt + am_stat = args.am_stat + phones_dict = args.phones_dict + print("phones_dict:", phones_dict) + tones_dict = args.tones_dict + speaker_dict = args.speaker_dict + voc = args.voc + voc_config = args.voc_config + voc_ckpt = args.voc_ckpt + voc_stat = args.voc_stat + lang = args.lang + device = args.device + output = args.output + spk_id = args.spk_id + + try: + res = self( + text=text, + # acoustic model related + am=am, + am_config=am_config, + am_ckpt=am_ckpt, + am_stat=am_stat, + phones_dict=phones_dict, + tones_dict=tones_dict, + speaker_dict=speaker_dict, + spk_id=spk_id, + # vocoder related + voc=voc, + voc_config=voc_config, + voc_ckpt=voc_ckpt, + voc_stat=voc_stat, + # other + lang=lang, + device=device, + output=output) + logger.info('TTS Result Saved in: {}'.format(res)) + return True + except Exception as e: + logger.exception(e) + return False + + def __call__(self, + text: str, + am: str='fastspeech2_csmsc', + am_config: Optional[os.PathLike]=None, + am_ckpt: Optional[os.PathLike]=None, + am_stat: Optional[os.PathLike]=None, + spk_id: int=0, + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None, + speaker_dict: Optional[os.PathLike]=None, + voc: str='pwgan_csmsc', + voc_config: Optional[os.PathLike]=None, + voc_ckpt: Optional[os.PathLike]=None, + voc_stat: Optional[os.PathLike]=None, + lang: str='zh', + device: str='gpu', + output: str='output.wav'): + """ + Python API to call an executor. + """ + paddle.set_device(device) + self._init_from_path( + am=am, + am_config=am_config, + am_ckpt=am_ckpt, + am_stat=am_stat, + phones_dict=phones_dict, + tones_dict=tones_dict, + speaker_dict=speaker_dict, + voc=voc, + voc_config=voc_config, + voc_ckpt=voc_ckpt, + voc_stat=voc_stat, + lang=lang) + + self.infer(text=text, lang=lang, am=am, spk_id=spk_id) + + res = self.postprocess(output=output) + + return res diff --git a/paddlespeech/cli/utils.py b/paddlespeech/cli/utils.py index eb023c11b..6ae6e7e52 100644 --- a/paddlespeech/cli/utils.py +++ b/paddlespeech/cli/utils.py @@ -22,8 +22,8 @@ from typing import Dict from typing import List from paddle.framework import load -from paddle.utils import download +from . import download from .entry import commands __all__ = [ @@ -78,7 +78,6 @@ def _md5check(filepath: os.PathLike, md5sum: str) -> bool: def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike: file_dir = os.path.dirname(filepath) - if tarfile.is_tarfile(filepath): files = tarfile.open(filepath, "r:*") file_list = files.getnames() @@ -87,12 +86,11 @@ def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike: file_list = files.namelist() else: return file_dir - if _is_a_single_file(file_list): rootpath = file_list[0] uncompressed_path = os.path.join(file_dir, rootpath) elif _is_a_single_dir(file_list): - rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0] uncompressed_path = os.path.join(file_dir, rootpath) else: rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] diff --git a/paddlespeech/cls/exps/panns/deploy/predict.py b/paddlespeech/cls/exps/panns/deploy/predict.py index d4e5c22fb..ee566ed4f 100644 --- a/paddlespeech/cls/exps/panns/deploy/predict.py +++ b/paddlespeech/cls/exps/panns/deploy/predict.py @@ -16,11 +16,10 @@ import os import numpy as np from paddle import inference -from scipy.special import softmax - from paddleaudio.backends import load as load_audio from paddleaudio.datasets import ESC50 from paddleaudio.features import melspectrogram +from scipy.special import softmax # yapf: disable parser = argparse.ArgumentParser() diff --git a/paddlespeech/cls/exps/panns/export_model.py b/paddlespeech/cls/exps/panns/export_model.py index c295c6a33..63b22981a 100644 --- a/paddlespeech/cls/exps/panns/export_model.py +++ b/paddlespeech/cls/exps/panns/export_model.py @@ -15,8 +15,8 @@ import argparse import os import paddle - from paddleaudio.datasets import ESC50 + from paddlespeech.cls.models import cnn14 from paddlespeech.cls.models import SoundClassifier diff --git a/paddlespeech/cls/exps/panns/predict.py b/paddlespeech/cls/exps/panns/predict.py index 9cfd8b6ce..0a1b6cccf 100644 --- a/paddlespeech/cls/exps/panns/predict.py +++ b/paddlespeech/cls/exps/panns/predict.py @@ -16,11 +16,11 @@ import argparse import numpy as np import paddle import paddle.nn.functional as F - from paddleaudio.backends import load as load_audio from paddleaudio.datasets import ESC50 from paddleaudio.features import LogMelSpectrogram from paddleaudio.features import melspectrogram + from paddlespeech.cls.models import cnn14 from paddlespeech.cls.models import SoundClassifier diff --git a/paddlespeech/cls/exps/panns/train.py b/paddlespeech/cls/exps/panns/train.py index 121309789..9508a977e 100644 --- a/paddlespeech/cls/exps/panns/train.py +++ b/paddlespeech/cls/exps/panns/train.py @@ -15,11 +15,11 @@ import argparse import os import paddle - from paddleaudio.datasets import ESC50 from paddleaudio.features import LogMelSpectrogram from paddleaudio.utils import logger from paddleaudio.utils import Timer + from paddlespeech.cls.models import cnn14 from paddlespeech.cls.models import SoundClassifier diff --git a/paddlespeech/cls/models/panns/panns.py b/paddlespeech/cls/models/panns/panns.py index 6d2dac56a..b442b2fd1 100644 --- a/paddlespeech/cls/models/panns/panns.py +++ b/paddlespeech/cls/models/panns/panns.py @@ -15,7 +15,6 @@ import os import paddle.nn as nn import paddle.nn.functional as F - from paddleaudio.utils.download import load_state_dict_from_url from paddleaudio.utils.env import MODEL_HOME diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/__init__.py b/paddlespeech/s2t/exps/deepspeech2/bin/__init__.py index e69de29bb..185a92b8d 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/__init__.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/__init__.py b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/__init__.py index e69de29bb..185a92b8d 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/__init__.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/exps/lm/transformer/bin/__init__.py b/paddlespeech/s2t/exps/lm/transformer/bin/__init__.py index e69de29bb..185a92b8d 100644 --- a/paddlespeech/s2t/exps/lm/transformer/bin/__init__.py +++ b/paddlespeech/s2t/exps/lm/transformer/bin/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/exps/u2/bin/__init__.py b/paddlespeech/s2t/exps/u2/bin/__init__.py index e69de29bb..185a92b8d 100644 --- a/paddlespeech/s2t/exps/u2/bin/__init__.py +++ b/paddlespeech/s2t/exps/u2/bin/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/exps/u2_kaldi/bin/__init__.py b/paddlespeech/s2t/exps/u2_kaldi/bin/__init__.py index e69de29bb..185a92b8d 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/bin/__init__.py +++ b/paddlespeech/s2t/exps/u2_kaldi/bin/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/exps/u2_st/bin/__init__.py b/paddlespeech/s2t/exps/u2_st/bin/__init__.py index e69de29bb..185a92b8d 100644 --- a/paddlespeech/s2t/exps/u2_st/bin/__init__.py +++ b/paddlespeech/s2t/exps/u2_st/bin/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 144eed9d8..3ec2c920e 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -26,10 +26,8 @@ from paddle import distributed as dist from paddle.io import DataLoader from yacs.config import CfgNode -from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.collator import SpeechCollator from paddlespeech.s2t.io.collator import TripletSpeechCollator -from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.io.dataset import ManifestDataset from paddlespeech.s2t.io.sampler import SortagradBatchSampler from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler diff --git a/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e_en.py b/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e_en.py index 095d20821..6a326e8a4 100644 --- a/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e_en.py +++ b/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e_en.py @@ -71,8 +71,7 @@ def evaluate(args, fastspeech2_config, pwg_config): vocoder.eval() print("model done!") - frontend = English() - punc = ":,;。?!“”‘’':,;.?!" + frontend = English(phone_vocab_path=args.phones_dict) print("frontend done!") stat = np.load(args.fastspeech2_stat) @@ -95,16 +94,8 @@ def evaluate(args, fastspeech2_config, pwg_config): # only test the number 0 speaker spk_id = 0 for utt_id, sentence in sentences: - phones = frontend.phoneticize(sentence) - # remove start_symbol and end_symbol - phones = phones[1:-1] - phones = [phn for phn in phones if not phn.isspace()] - phones = [ - phn if (phn in phone_id_map and phn not in punc) else "sp" - for phn in phones - ] - phone_ids = [phone_id_map[phn] for phn in phones] - phone_ids = paddle.to_tensor(phone_ids) + input_ids = frontend.get_input_ids(sentence) + phone_ids = input_ids["phone_ids"] with paddle.no_grad(): mel = fastspeech2_inference( diff --git a/paddlespeech/t2s/exps/fastspeech2/synthesize_e2e_en.py b/paddlespeech/t2s/exps/fastspeech2/synthesize_e2e_en.py index 6e3434a78..1b980afb1 100644 --- a/paddlespeech/t2s/exps/fastspeech2/synthesize_e2e_en.py +++ b/paddlespeech/t2s/exps/fastspeech2/synthesize_e2e_en.py @@ -63,8 +63,7 @@ def evaluate(args, fastspeech2_config, pwg_config): vocoder.eval() print("model done!") - frontend = English() - punc = ":,;。?!“”‘’':,;.?!" + frontend = English(phone_vocab_path=args.phones_dict) print("frontend done!") stat = np.load(args.fastspeech2_stat) @@ -86,16 +85,8 @@ def evaluate(args, fastspeech2_config, pwg_config): output_dir.mkdir(parents=True, exist_ok=True) for utt_id, sentence in sentences: - phones = frontend.phoneticize(sentence) - # remove start_symbol and end_symbol - phones = phones[1:-1] - phones = [phn for phn in phones if not phn.isspace()] - phones = [ - phn if (phn in phone_id_map and phn not in punc) else "sp" - for phn in phones - ] - phone_ids = [phone_id_map[phn] for phn in phones] - phone_ids = paddle.to_tensor(phone_ids) + input_ids = frontend.get_input_ids(sentence) + phone_ids = input_ids["phone_ids"] with paddle.no_grad(): mel = fastspeech2_inference(phone_ids) diff --git a/paddlespeech/t2s/frontend/phonectic.py b/paddlespeech/t2s/frontend/phonectic.py index 8eac0b48e..fbc8fd388 100644 --- a/paddlespeech/t2s/frontend/phonectic.py +++ b/paddlespeech/t2s/frontend/phonectic.py @@ -14,6 +14,7 @@ from abc import ABC from abc import abstractmethod +import paddle from g2p_en import G2p from g2pM import G2pM @@ -45,20 +46,25 @@ class English(Phonetics): """ Normalize the input text sequence and convert into pronunciation id sequence. """ - def __init__(self): + def __init__(self, phone_vocab_path=None): self.backend = G2p() self.phonemes = list(self.backend.phonemes) self.punctuations = get_punctuations("en") self.vocab = Vocab(self.phonemes + self.punctuations) + self.vocab_phones = {} + self.punc = ":,;。?!“”‘’':,;.?!" + if phone_vocab_path: + with open(phone_vocab_path, 'rt') as f: + phn_id = [line.strip().split() for line in f.readlines()] + for phn, id in phn_id: + self.vocab_phones[phn] = int(id) def phoneticize(self, sentence): """ Normalize the input text sequence and convert it into pronunciation sequence. - Parameters ----------- sentence: str The input text sequence. - Returns ---------- List[str] @@ -72,14 +78,27 @@ class English(Phonetics): phonemes = [item for item in phonemes if item in self.vocab.stoi] return phonemes + def get_input_ids(self, sentence: str) -> paddle.Tensor: + result = {} + phones = self.phoneticize(sentence) + # remove start_symbol and end_symbol + phones = phones[1:-1] + phones = [phn for phn in phones if not phn.isspace()] + phones = [ + phn if (phn in self.vocab_phones and phn not in self.punc) else "sp" + for phn in phones + ] + phone_ids = [self.vocab_phones[phn] for phn in phones] + phone_ids = paddle.to_tensor(phone_ids) + result["phone_ids"] = phone_ids + return result + def numericalize(self, phonemes): """ Convert pronunciation sequence into pronunciation id sequence. - Parameters ----------- phonemes: List[str] The list of pronunciation sequence. - Returns ---------- List[int] @@ -93,12 +112,10 @@ class English(Phonetics): def reverse(self, ids): """ Reverse the list of pronunciation id sequence to a list of pronunciation sequence. - Parameters ----------- ids: List[int] The list of pronunciation id sequence. - Returns ---------- List[str] @@ -108,12 +125,10 @@ class English(Phonetics): def __call__(self, sentence): """ Convert the input text sequence into pronunciation id sequence. - Parameters ----------- sentence: str The input text sequence. - Returns ---------- List[str] @@ -140,12 +155,10 @@ class EnglishCharacter(Phonetics): def phoneticize(self, sentence): """ Normalize the input text sequence. - Parameters ----------- sentence: str The input text sequence. - Returns ---------- str @@ -156,12 +169,10 @@ class EnglishCharacter(Phonetics): def numericalize(self, sentence): """ Convert a text sequence into ids. - Parameters ----------- sentence: str The input text sequence. - Returns ---------- List[int] @@ -175,17 +186,14 @@ class EnglishCharacter(Phonetics): def reverse(self, ids): """ Convert a character id sequence into text. - Parameters ----------- ids: List[int] List of a character id sequence. - Returns ---------- str The input text sequence. - """ return [self.vocab.reverse(i) for i in ids] @@ -195,7 +203,6 @@ class EnglishCharacter(Phonetics): ----------- sentence: str The input text sequence. - Returns ---------- List[int] @@ -229,12 +236,10 @@ class Chinese(Phonetics): def phoneticize(self, sentence): """ Normalize the input text sequence and convert it into pronunciation sequence. - Parameters ----------- sentence: str The input text sequence. - Returns ---------- List[str] @@ -263,12 +268,10 @@ class Chinese(Phonetics): def numericalize(self, phonemes): """ Convert pronunciation sequence into pronunciation id sequence. - Parameters ----------- phonemes: List[str] The list of pronunciation sequence. - Returns ---------- List[int] @@ -279,12 +282,10 @@ class Chinese(Phonetics): def __call__(self, sentence): """ Convert the input text sequence into pronunciation id sequence. - Parameters ----------- sentence: str The input text sequence. - Returns ---------- List[str] @@ -300,12 +301,10 @@ class Chinese(Phonetics): def reverse(self, ids): """ Reverse the list of pronunciation id sequence to a list of pronunciation sequence. - Parameters ----------- ids: List[int] The list of pronunciation id sequence. - Returns ---------- List[str] diff --git a/utils/manifest_key_value.py b/utils/manifest_key_value.py index fb3d3aaaf..3a8009039 100755 --- a/utils/manifest_key_value.py +++ b/utils/manifest_key_value.py @@ -5,7 +5,6 @@ import functools from pathlib import Path import jsonlines - from utils.utility import add_arguments from utils.utility import print_arguments