Merge pull request #1092 from yt605155624/tts_cli

[cli]add tts cli
pull/1093/head
Hui Zhang 3 years ago committed by GitHub
commit 12318566a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,7 +25,6 @@ import os
from pathlib import Path
import soundfile
from utils.utility import download
from utils.utility import unpack

@ -25,7 +25,6 @@ import os
from pathlib import Path
import soundfile
from utils.utility import download
from utils.utility import unpack

@ -27,7 +27,6 @@ import os
from multiprocessing.pool import Pool
import soundfile
from utils.utility import download
from utils.utility import unpack

@ -26,7 +26,6 @@ import os
from multiprocessing.pool import Pool
import soundfile
from utils.utility import download
from utils.utility import unpack

@ -28,7 +28,6 @@ import json
import os
import soundfile
from utils.utility import download
from utils.utility import unpack

@ -28,7 +28,6 @@ import json
import os
import soundfile
from utils.utility import download
from utils.utility import unzip

@ -26,7 +26,6 @@ from multiprocessing.pool import Pool
from pathlib import Path
import soundfile
from utils.utility import download
from utils.utility import unpack

@ -27,7 +27,6 @@ import string
from pathlib import Path
import soundfile
from utils.utility import unzip
URL_ROOT = ""

@ -27,7 +27,6 @@ import shutil
import subprocess
import soundfile
from utils.utility import download_multi
from utils.utility import getfile_insensitive
from utils.utility import unpack

@ -16,3 +16,4 @@ from .base_commands import BaseCommand
from .base_commands import HelpCommand
from .cls import CLSExecutor
from .st import STExecutor
from .tts import TTSExecutor

@ -20,14 +20,14 @@ from typing import Union
import numpy as np
import paddle
import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from ..executor import BaseExecutor
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import logger
from ..utils import MODEL_HOME
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
__all__ = ['CLSExecutor']

@ -0,0 +1,376 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import os
import os.path as osp
import shutil
import subprocess
import sys
import tarfile
import time
import zipfile
import requests
try:
from tqdm import tqdm
except:
class tqdm(object):
def __init__(self, total=None):
self.total = total
self.n = 0
def update(self, n):
self.n += n
if self.total is None:
sys.stderr.write("\r{0:.1f} bytes".format(self.n))
else:
sys.stderr.write(
"\r{0:.1f}%".format(100 * self.n / float(self.total)))
sys.stderr.flush()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stderr.write('\n')
import logging
logger = logging.getLogger(__name__)
__all__ = ['get_weights_path_from_url']
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
DOWNLOAD_RETRY_LIMIT = 3
def is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def get_weights_path_from_url(url, md5sum=None):
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.
Args:
url (str): download url
md5sum (str): md5 sum of download package
Returns:
str: a local path to save downloaded weights.
Examples:
.. code-block:: python
from paddle.utils.download import get_weights_path_from_url
resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)
"""
path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
return path
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = osp.split(url)[-1]
fpath = fname
return osp.join(root_dir, fpath)
def _get_unique_endpoints(trainer_endpoints):
# Sorting is to avoid different environmental variables for each card
trainer_endpoints.sort()
ips = set()
unique_endpoints = set()
for endpoint in trainer_endpoints:
ip = endpoint.split(":")[0]
if ip in ips:
continue
ips.add(ip)
unique_endpoints.add(endpoint)
logger.info("unique_endpoints {}".format(unique_endpoints))
return unique_endpoints
def get_path_from_url(url,
root_dir,
md5sum=None,
check_exist=True,
decompress=True,
method='get'):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
from paddle.fluid.dygraph.parallel import ParallelEnv
assert is_url(url), "downloading from {} not a url".format(url)
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)
# Mainly used to solve the problem of downloading data from different
# machines in the case of multiple machines. Different ips will download
# data, and the same ip will only download data once.
unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum, method=method)
else:
while not os.path.exists(fullpath):
time.sleep(1)
if ParallelEnv().current_endpoint in unique_endpoints:
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
def _get_download(url, fullname):
# using requests.get method
fname = osp.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _wget_download(url, fullname):
# using wget to download url
tmp_fullname = fullname + "_tmp"
# user-agent
command = 'wget -O {} -t {} {}'.format(tmp_fullname, DOWNLOAD_RETRY_LIMIT,
url)
subprc = subprocess.Popen(
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
_ = subprc.communicate()
if subprc.returncode != 0:
raise RuntimeError(
'{} failed. Please make sure `wget` is installed or {} exists'.
format(command, url))
shutil.move(tmp_fullname, fullname)
return fullname
_download_methods = {
'get': _get_download,
'wget': _wget_download,
}
def _download(url, path, md5sum=None, method='get'):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
md5sum (str): md5 sum of download package
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
"""
assert method in _download_methods, 'make sure `{}` implemented'.format(
method)
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _download_methods[method](url, fullname):
time.sleep(1)
continue
return fullname
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
logger.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# successed, move decompress files to fpath and delete
# fpath_tmp and remove download compress file.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
def _uncompress_file_zip(filepath):
files = zipfile.ZipFile(filepath, 'r')
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _uncompress_file_tar(filepath, mode="r:*"):
files = tarfile.open(filepath, mode)
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
for item in file_list:
files.extract(item, file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
for item in file_list:
files.extract(item, os.path.join(file_dir, rootpath))
files.close()
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < -1:
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True

@ -0,0 +1,14 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .infer import TTSExecutor

@ -0,0 +1,641 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from typing import Any
from typing import List
from typing import Optional
from typing import Union
import numpy as np
import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import logger
from ..utils import MODEL_HOME
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
__all__ = ['TTSExecutor']
pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip',
'md5':
'9edce23b1a87f31b814d9477bf52afbc',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_11400.pdz',
'speech_stats':
'feats_stats.npy',
'phones_dict':
'phone_id_map.txt',
'tones_dict':
'tone_id_map.txt',
},
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_ljspeech_ckpt_0.5.zip',
'md5':
'ffed800c93deaf16ca9b3af89bfcd747',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_100000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip',
'md5':
'f4dd4a5f49a4552b77981f544ab3392e',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_96400.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
"fastspeech2_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_vctk_ckpt_0.5.zip',
'md5':
'743e5024ca1e17a88c5c271db9779ba4',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_66200.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
'speaker_dict':
'speaker_id_map.txt',
},
# pwgan
"pwgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip',
'md5':
'2e481633325b5bdf0a3823c714d2c117',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
"pwgan_ljspeech-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_ljspeech_ckpt_0.5.zip',
'md5':
'53610ba9708fd3008ccaf8e99dacbaf0',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_400000.pdz',
'speech_stats':
'pwg_stats.npy',
},
"pwgan_aishell3-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip',
'md5':
'd7598fa41ad362d62f85ffc0f07e3d84',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
"pwgan_vctk-en": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_vctk_ckpt_0.5.zip',
'md5':
'322ca688aec9b127cec2788b65aa3d52',
'config':
'pwg_default.yaml',
'ckpt':
'pwg_snapshot_iter_1000000.pdz',
'speech_stats':
'pwg_stats.npy',
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_baker_finetune_ckpt_0.5.zip',
'md5':
'b69322ab4ea766d955bd3d9af7dc5f2d',
'config':
'finetune.yaml',
'ckpt':
'snapshot_iter_2000000.pdz',
'speech_stats':
'feats_stats.npy',
},
}
model_alias = {
# acoustic model
"speedyspeech":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeech",
"speedyspeech_inference":
"paddlespeech.t2s.models.speedyspeech:SpeedySpeechInference",
"fastspeech2":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2",
"fastspeech2_inference":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
# voc
"pwgan":
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
"pwgan_inference":
"paddlespeech.t2s.models.parallel_wavegan:PWGInference",
"mb_melgan":
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
}
@cli_register(
name='paddlespeech.tts', description='Text to Speech infer command.')
class TTSExecutor(BaseExecutor):
def __init__(self):
super().__init__()
self.parser = argparse.ArgumentParser(
prog='paddlespeech.tts', add_help=True)
self.parser.add_argument(
'--input', type=str, required=True, help='Input text to generate.')
# acoustic model
self.parser.add_argument(
'--am',
type=str,
default='fastspeech2_csmsc',
choices=[
'speedyspeech_csmsc', 'fastspeech2_csmsc',
'fastspeech2_ljspeech', 'fastspeech2_aishell3',
'fastspeech2_vctk'
],
help='Choose acoustic model type of tts task.')
self.parser.add_argument(
'--am_config',
type=str,
default=None,
help='Config of acoustic model. Use deault config when it is None.')
self.parser.add_argument(
'--am_ckpt',
type=str,
default=None,
help='Checkpoint file of acoustic model.')
self.parser.add_argument(
"--am_stat",
type=str,
help="mean and standard deviation used to normalize spectrogram when training acoustic model."
)
self.parser.add_argument(
"--phones_dict",
type=str,
default=None,
help="phone vocabulary file.")
self.parser.add_argument(
"--tones_dict",
type=str,
default=None,
help="tone vocabulary file.")
self.parser.add_argument(
"--speaker_dict",
type=str,
default=None,
help="speaker id map file.")
self.parser.add_argument(
'--spk_id',
type=int,
default=0,
help='spk id for multi speaker acoustic model')
# vocoder
self.parser.add_argument(
'--voc',
type=str,
default='pwgan_csmsc',
choices=[
'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
'mb_melgan_csmsc'
],
help='Choose vocoder type of tts task.')
self.parser.add_argument(
'--voc_config',
type=str,
default=None,
help='Config of voc. Use deault config when it is None.')
self.parser.add_argument(
'--voc_ckpt',
type=str,
default=None,
help='Checkpoint file of voc.')
self.parser.add_argument(
"--voc_stat",
type=str,
help="mean and standard deviation used to normalize spectrogram when training voc."
)
# other
self.parser.add_argument(
'--lang',
type=str,
default='zh',
help='Choose model language. zh or en')
self.parser.add_argument(
'--device',
type=str,
default=paddle.get_device(),
help='Choose device to execute model inference.')
self.parser.add_argument(
'--output', type=str, default='output.wav', help='output file name')
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format(
tag)
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(
self,
am: str='fastspeech2_csmsc',
am_config: Optional[os.PathLike]=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
voc: str='pwgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None,
lang: str='zh', ):
"""
Init model and other resources from a specific path.
"""
if hasattr(self, 'am') and hasattr(self, 'voc'):
logger.info('Models had been initialized.')
return
# am
am_tag = am + '-' + lang
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_config = os.path.join(am_res_path,
pretrained_models[am_tag]['config'])
self.am_ckpt = os.path.join(am_res_path,
pretrained_models[am_tag]['ckpt'])
self.am_stat = os.path.join(
am_res_path, pretrained_models[am_tag]['speech_stats'])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, pretrained_models[am_tag]['phones_dict'])
print("self.phones_dict:", self.phones_dict)
logger.info(am_res_path)
logger.info(self.am_config)
logger.info(self.am_ckpt)
else:
self.am_config = os.path.abspath(am_config)
self.am_ckpt = os.path.abspath(am_ckpt)
self.am_stat = os.path.abspath(am_stat)
self.phones_dict = os.path.abspath(phones_dict)
self.am_res_path = os.path.dirname(os.path.abspath(self.am_config))
print("self.phones_dict:", self.phones_dict)
# for speedyspeech
self.tones_dict = None
if 'tones_dict' in pretrained_models[am_tag]:
self.tones_dict = os.path.join(
am_res_path, pretrained_models[am_tag]['tones_dict'])
if tones_dict:
self.tones_dict = tones_dict
# for multi speaker fastspeech2
self.speaker_dict = None
if 'speaker_dict' in pretrained_models[am_tag]:
self.speaker_dict = os.path.join(
am_res_path, pretrained_models[am_tag]['speaker_dict'])
if speaker_dict:
self.speaker_dict = speaker_dict
# voc
voc_tag = voc + '-' + lang
if voc_ckpt is None or voc_config is None or voc_stat is None:
voc_res_path = self._get_pretrained_path(voc_tag)
self.voc_res_path = voc_res_path
self.voc_config = os.path.join(voc_res_path,
pretrained_models[voc_tag]['config'])
self.voc_ckpt = os.path.join(voc_res_path,
pretrained_models[voc_tag]['ckpt'])
self.voc_stat = os.path.join(
voc_res_path, pretrained_models[voc_tag]['speech_stats'])
logger.info(voc_res_path)
logger.info(self.voc_config)
logger.info(self.voc_ckpt)
else:
self.voc_config = os.path.abspath(voc_config)
self.voc_ckpt = os.path.abspath(voc_ckpt)
self.voc_stat = os.path.abspath(voc_stat)
self.voc_res_path = os.path.dirname(
os.path.abspath(self.voc_config))
# Init body.
with open(self.am_config) as f:
self.am_config = CfgNode(yaml.safe_load(f))
with open(self.voc_config) as f:
self.voc_config = CfgNode(yaml.safe_load(f))
# Enter the path of model root
with open(self.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
tone_size = None
if self.tones_dict:
with open(self.tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id)
print("tone_size:", tone_size)
spk_num = None
if self.speaker_dict:
with open(self.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
print("spk_num:", spk_num)
# frontend
if lang == 'zh':
self.frontend = Frontend(
phone_vocab_path=self.phones_dict,
tone_vocab_path=self.tones_dict)
elif lang == 'en':
self.frontend = English(phone_vocab_path=self.phones_dict)
print("frontend done!")
# acoustic model
odim = self.am_config.n_mels
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_class = dynamic_import(am_name, model_alias)
am_inference_class = dynamic_import(am_name + '_inference', model_alias)
if am_name == 'fastspeech2':
am = am_class(
idim=vocab_size,
odim=odim,
spk_num=spk_num,
**self.am_config["model"])
elif am_name == 'speedyspeech':
am = am_class(
vocab_size=vocab_size,
tone_size=tone_size,
**self.am_config["model"])
am.set_state_dict(paddle.load(self.am_ckpt)["main_params"])
am.eval()
am_mu, am_std = np.load(self.am_stat)
am_mu = paddle.to_tensor(am_mu)
am_std = paddle.to_tensor(am_std)
am_normalizer = ZScore(am_mu, am_std)
self.am_inference = am_inference_class(am_normalizer, am)
print("acoustic model done!")
# vocoder
# model: {model_name}_{dataset}
voc_name = '_'.join(voc.split('_')[:-1])
voc_class = dynamic_import(voc_name, model_alias)
voc_inference_class = dynamic_import(voc_name + '_inference',
model_alias)
voc = voc_class(**self.voc_config["generator_params"])
voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"])
voc.remove_weight_norm()
voc.eval()
voc_mu, voc_std = np.load(self.voc_stat)
voc_mu = paddle.to_tensor(voc_mu)
voc_std = paddle.to_tensor(voc_std)
voc_normalizer = ZScore(voc_mu, voc_std)
self.voc_inference = voc_inference_class(voc_normalizer, voc)
print("voc done!")
def preprocess(self, input: Any, *args, **kwargs):
"""
Input preprocess and return paddle.Tensor stored in self._inputs.
Input content can be a text(tts), a file(asr, cls), a stream(not supported yet) or anything needed.
Args:
input (Any): Input text/file/stream or other content.
"""
pass
@paddle.no_grad()
def infer(self,
text: str,
lang: str='zh',
am: str='fastspeech2_csmsc',
spk_id: int=0):
"""
Model inference and result stored in self.output.
"""
model_name = am[:am.rindex('_')]
dataset = am[am.rindex('_') + 1:]
get_tone_ids = False
if 'speedyspeech' in model_name:
get_tone_ids = True
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
text, merge_sentences=True, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
phone_ids = phone_ids[0]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
tone_ids = tone_ids[0]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(text)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
# am
if 'speedyspeech' in model_name:
mel = self.am_inference(phone_ids, tone_ids)
# fastspeech2
else:
# multi speaker
if dataset in {"aishell3", "vctk"}:
mel = self.am_inference(
phone_ids, spk_id=paddle.to_tensor(spk_id))
else:
mel = self.am_inference(phone_ids)
# voc
wav = self.voc_inference(mel)
self._outputs['wav'] = wav
def postprocess(self, output: str='output.wav'):
"""
Output postprocess and return results.
This method get model output from self._outputs and convert it into human-readable results.
Returns:
Union[str, os.PathLike]: Human-readable results such as texts and audio files.
"""
sf.write(
output, self._outputs['wav'].numpy(), samplerate=self.am_config.fs)
return output
def execute(self, argv: List[str]) -> bool:
"""
Command line entry.
"""
args = self.parser.parse_args(argv)
text = args.input
am = args.am
am_config = args.am_config
am_ckpt = args.am_ckpt
am_stat = args.am_stat
phones_dict = args.phones_dict
print("phones_dict:", phones_dict)
tones_dict = args.tones_dict
speaker_dict = args.speaker_dict
voc = args.voc
voc_config = args.voc_config
voc_ckpt = args.voc_ckpt
voc_stat = args.voc_stat
lang = args.lang
device = args.device
output = args.output
spk_id = args.spk_id
try:
res = self(
text=text,
# acoustic model related
am=am,
am_config=am_config,
am_ckpt=am_ckpt,
am_stat=am_stat,
phones_dict=phones_dict,
tones_dict=tones_dict,
speaker_dict=speaker_dict,
spk_id=spk_id,
# vocoder related
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
# other
lang=lang,
device=device,
output=output)
logger.info('TTS Result Saved in: {}'.format(res))
return True
except Exception as e:
logger.exception(e)
return False
def __call__(self,
text: str,
am: str='fastspeech2_csmsc',
am_config: Optional[os.PathLike]=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
spk_id: int=0,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
voc: str='pwgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None,
lang: str='zh',
device: str='gpu',
output: str='output.wav'):
"""
Python API to call an executor.
"""
paddle.set_device(device)
self._init_from_path(
am=am,
am_config=am_config,
am_ckpt=am_ckpt,
am_stat=am_stat,
phones_dict=phones_dict,
tones_dict=tones_dict,
speaker_dict=speaker_dict,
voc=voc,
voc_config=voc_config,
voc_ckpt=voc_ckpt,
voc_stat=voc_stat,
lang=lang)
self.infer(text=text, lang=lang, am=am, spk_id=spk_id)
res = self.postprocess(output=output)
return res

@ -22,8 +22,8 @@ from typing import Dict
from typing import List
from paddle.framework import load
from paddle.utils import download
from . import download
from .entry import commands
__all__ = [
@ -78,7 +78,6 @@ def _md5check(filepath: os.PathLike, md5sum: str) -> bool:
def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
file_dir = os.path.dirname(filepath)
if tarfile.is_tarfile(filepath):
files = tarfile.open(filepath, "r:*")
file_list = files.getnames()
@ -87,12 +86,11 @@ def _get_uncompress_path(filepath: os.PathLike) -> os.PathLike:
file_list = files.namelist()
else:
return file_dir
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1]
rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[0]
uncompressed_path = os.path.join(file_dir, rootpath)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]

@ -16,11 +16,10 @@ import os
import numpy as np
from paddle import inference
from scipy.special import softmax
from paddleaudio.backends import load as load_audio
from paddleaudio.datasets import ESC50
from paddleaudio.features import melspectrogram
from scipy.special import softmax
# yapf: disable
parser = argparse.ArgumentParser()

@ -15,8 +15,8 @@ import argparse
import os
import paddle
from paddleaudio.datasets import ESC50
from paddlespeech.cls.models import cnn14
from paddlespeech.cls.models import SoundClassifier

@ -16,11 +16,11 @@ import argparse
import numpy as np
import paddle
import paddle.nn.functional as F
from paddleaudio.backends import load as load_audio
from paddleaudio.datasets import ESC50
from paddleaudio.features import LogMelSpectrogram
from paddleaudio.features import melspectrogram
from paddlespeech.cls.models import cnn14
from paddlespeech.cls.models import SoundClassifier

@ -15,11 +15,11 @@ import argparse
import os
import paddle
from paddleaudio.datasets import ESC50
from paddleaudio.features import LogMelSpectrogram
from paddleaudio.utils import logger
from paddleaudio.utils import Timer
from paddlespeech.cls.models import cnn14
from paddlespeech.cls.models import SoundClassifier

@ -15,7 +15,6 @@ import os
import paddle.nn as nn
import paddle.nn.functional as F
from paddleaudio.utils.download import load_state_dict_from_url
from paddleaudio.utils.env import MODEL_HOME

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -26,10 +26,8 @@ from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.collator import TripletSpeechCollator
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.io.sampler import SortagradBatchSampler
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler

@ -71,8 +71,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
vocoder.eval()
print("model done!")
frontend = English()
punc = ":,;。?!“”‘’':,;.?!"
frontend = English(phone_vocab_path=args.phones_dict)
print("frontend done!")
stat = np.load(args.fastspeech2_stat)
@ -95,16 +94,8 @@ def evaluate(args, fastspeech2_config, pwg_config):
# only test the number 0 speaker
spk_id = 0
for utt_id, sentence in sentences:
phones = frontend.phoneticize(sentence)
# remove start_symbol and end_symbol
phones = phones[1:-1]
phones = [phn for phn in phones if not phn.isspace()]
phones = [
phn if (phn in phone_id_map and phn not in punc) else "sp"
for phn in phones
]
phone_ids = [phone_id_map[phn] for phn in phones]
phone_ids = paddle.to_tensor(phone_ids)
input_ids = frontend.get_input_ids(sentence)
phone_ids = input_ids["phone_ids"]
with paddle.no_grad():
mel = fastspeech2_inference(

@ -63,8 +63,7 @@ def evaluate(args, fastspeech2_config, pwg_config):
vocoder.eval()
print("model done!")
frontend = English()
punc = ":,;。?!“”‘’':,;.?!"
frontend = English(phone_vocab_path=args.phones_dict)
print("frontend done!")
stat = np.load(args.fastspeech2_stat)
@ -86,16 +85,8 @@ def evaluate(args, fastspeech2_config, pwg_config):
output_dir.mkdir(parents=True, exist_ok=True)
for utt_id, sentence in sentences:
phones = frontend.phoneticize(sentence)
# remove start_symbol and end_symbol
phones = phones[1:-1]
phones = [phn for phn in phones if not phn.isspace()]
phones = [
phn if (phn in phone_id_map and phn not in punc) else "sp"
for phn in phones
]
phone_ids = [phone_id_map[phn] for phn in phones]
phone_ids = paddle.to_tensor(phone_ids)
input_ids = frontend.get_input_ids(sentence)
phone_ids = input_ids["phone_ids"]
with paddle.no_grad():
mel = fastspeech2_inference(phone_ids)

@ -14,6 +14,7 @@
from abc import ABC
from abc import abstractmethod
import paddle
from g2p_en import G2p
from g2pM import G2pM
@ -45,20 +46,25 @@ class English(Phonetics):
""" Normalize the input text sequence and convert into pronunciation id sequence.
"""
def __init__(self):
def __init__(self, phone_vocab_path=None):
self.backend = G2p()
self.phonemes = list(self.backend.phonemes)
self.punctuations = get_punctuations("en")
self.vocab = Vocab(self.phonemes + self.punctuations)
self.vocab_phones = {}
self.punc = ":,;。?!“”‘’':,;.?!"
if phone_vocab_path:
with open(phone_vocab_path, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id:
self.vocab_phones[phn] = int(id)
def phoneticize(self, sentence):
""" Normalize the input text sequence and convert it into pronunciation sequence.
Parameters
-----------
sentence: str
The input text sequence.
Returns
----------
List[str]
@ -72,14 +78,27 @@ class English(Phonetics):
phonemes = [item for item in phonemes if item in self.vocab.stoi]
return phonemes
def get_input_ids(self, sentence: str) -> paddle.Tensor:
result = {}
phones = self.phoneticize(sentence)
# remove start_symbol and end_symbol
phones = phones[1:-1]
phones = [phn for phn in phones if not phn.isspace()]
phones = [
phn if (phn in self.vocab_phones and phn not in self.punc) else "sp"
for phn in phones
]
phone_ids = [self.vocab_phones[phn] for phn in phones]
phone_ids = paddle.to_tensor(phone_ids)
result["phone_ids"] = phone_ids
return result
def numericalize(self, phonemes):
""" Convert pronunciation sequence into pronunciation id sequence.
Parameters
-----------
phonemes: List[str]
The list of pronunciation sequence.
Returns
----------
List[int]
@ -93,12 +112,10 @@ class English(Phonetics):
def reverse(self, ids):
""" Reverse the list of pronunciation id sequence to a list of pronunciation sequence.
Parameters
-----------
ids: List[int]
The list of pronunciation id sequence.
Returns
----------
List[str]
@ -108,12 +125,10 @@ class English(Phonetics):
def __call__(self, sentence):
""" Convert the input text sequence into pronunciation id sequence.
Parameters
-----------
sentence: str
The input text sequence.
Returns
----------
List[str]
@ -140,12 +155,10 @@ class EnglishCharacter(Phonetics):
def phoneticize(self, sentence):
""" Normalize the input text sequence.
Parameters
-----------
sentence: str
The input text sequence.
Returns
----------
str
@ -156,12 +169,10 @@ class EnglishCharacter(Phonetics):
def numericalize(self, sentence):
""" Convert a text sequence into ids.
Parameters
-----------
sentence: str
The input text sequence.
Returns
----------
List[int]
@ -175,17 +186,14 @@ class EnglishCharacter(Phonetics):
def reverse(self, ids):
""" Convert a character id sequence into text.
Parameters
-----------
ids: List[int]
List of a character id sequence.
Returns
----------
str
The input text sequence.
"""
return [self.vocab.reverse(i) for i in ids]
@ -195,7 +203,6 @@ class EnglishCharacter(Phonetics):
-----------
sentence: str
The input text sequence.
Returns
----------
List[int]
@ -229,12 +236,10 @@ class Chinese(Phonetics):
def phoneticize(self, sentence):
""" Normalize the input text sequence and convert it into pronunciation sequence.
Parameters
-----------
sentence: str
The input text sequence.
Returns
----------
List[str]
@ -263,12 +268,10 @@ class Chinese(Phonetics):
def numericalize(self, phonemes):
""" Convert pronunciation sequence into pronunciation id sequence.
Parameters
-----------
phonemes: List[str]
The list of pronunciation sequence.
Returns
----------
List[int]
@ -279,12 +282,10 @@ class Chinese(Phonetics):
def __call__(self, sentence):
""" Convert the input text sequence into pronunciation id sequence.
Parameters
-----------
sentence: str
The input text sequence.
Returns
----------
List[str]
@ -300,12 +301,10 @@ class Chinese(Phonetics):
def reverse(self, ids):
""" Reverse the list of pronunciation id sequence to a list of pronunciation sequence.
Parameters
-----------
ids: List[int]
The list of pronunciation id sequence.
Returns
----------
List[str]

@ -5,7 +5,6 @@ import functools
from pathlib import Path
import jsonlines
from utils.utility import add_arguments
from utils.utility import print_arguments

Loading…
Cancel
Save