Merge pull request #2040 from yt605155624/add_blank

[TTS]add blank between characters for vits
pull/2056/head
TianYuan 3 years ago committed by GitHub
commit 02734141ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -178,6 +178,8 @@ generator_first: False # whether to start updating generator first
########################################################## ##########################################################
# OTHER TRAINING SETTING # # OTHER TRAINING SETTING #
########################################################## ##########################################################
max_epoch: 1000 # number of epochs num_snapshots: 10 # max number of snapshots to keep while training
num_snapshots: 10 # max number of snapshots to keep while training train_max_steps: 250000 # Number of training steps. == total_iters / ngpus, total_iters = 1000000
seed: 777 # random seed number save_interval_steps: 1000 # Interval steps to save checkpoint.
eval_interval_steps: 250 # Interval steps to evaluate the network.
seed: 777 # random seed number

@ -4,6 +4,7 @@ stage=0
stop_stage=100 stop_stage=100
config_path=$1 config_path=$1
add_blank=$2
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# get durations from MFA's result # get durations from MFA's result
@ -44,6 +45,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--feats-stats=dump/train/feats_stats.npy \ --feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \ --phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \ --speaker-dict=dump/speaker_id_map.txt \
--add-blank=${add_blank} \
--skip-wav-copy --skip-wav-copy
python3 ${BIN_DIR}/normalize.py \ python3 ${BIN_DIR}/normalize.py \
@ -52,6 +54,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--feats-stats=dump/train/feats_stats.npy \ --feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \ --phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \ --speaker-dict=dump/speaker_id_map.txt \
--add-blank=${add_blank} \
--skip-wav-copy --skip-wav-copy
python3 ${BIN_DIR}/normalize.py \ python3 ${BIN_DIR}/normalize.py \
@ -60,5 +63,6 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--feats-stats=dump/train/feats_stats.npy \ --feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \ --phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \ --speaker-dict=dump/speaker_id_map.txt \
--add-blank=${add_blank} \
--skip-wav-copy --skip-wav-copy
fi fi

@ -3,9 +3,12 @@
config_path=$1 config_path=$1
train_output_path=$2 train_output_path=$2
ckpt_name=$3 ckpt_name=$3
add_blank=$4
stage=0 stage=0
stop_stage=0 stop_stage=0
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \ FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
@ -14,5 +17,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
--ckpt=${train_output_path}/checkpoints/${ckpt_name} \ --ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--phones_dict=dump/phone_id_map.txt \ --phones_dict=dump/phone_id_map.txt \
--output_dir=${train_output_path}/test_e2e \ --output_dir=${train_output_path}/test_e2e \
--text=${BIN_DIR}/../sentences.txt --text=${BIN_DIR}/../sentences.txt \
--add-blank=${add_blank}
fi fi

@ -10,6 +10,7 @@ stop_stage=100
conf_path=conf/default.yaml conf_path=conf/default.yaml
train_output_path=exp/default train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz ckpt_name=snapshot_iter_153.pdz
add_blank=true
# with the following command, you can choose the stage range you want to run # with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0` # such as `./run.sh --stage 0 --stop-stage 0`
@ -18,7 +19,7 @@ source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data # prepare data
./local/preprocess.sh ${conf_path} || exit -1 ./local/preprocess.sh ${conf_path} ${add_blank}|| exit -1
fi fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
@ -32,5 +33,5 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# synthesize_e2e, vocoder is pwgan # synthesize_e2e, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} ${add_blank}|| exit -1
fi fi

@ -8,5 +8,4 @@ python ${BIN_DIR}/synthesize.py \
--input=${input_mel_path} \ --input=${input_mel_path} \
--output=${train_output_path}/wavs/ \ --output=${train_output_path}/wavs/ \
--checkpoint_path=${train_output_path}/checkpoints/${ckpt_name} \ --checkpoint_path=${train_output_path}/checkpoints/${ckpt_name} \
--ngpu=1 \ --ngpu=1
--verbose

@ -58,30 +58,8 @@ def main():
"--phones-dict", type=str, default=None, help="phone vocabulary file.") "--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument( parser.add_argument(
"--speaker-dict", type=str, default=None, help="speaker id map file.") "--speaker-dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
args = parser.parse_args()
# set logger args = parser.parse_args()
if args.verbose > 1:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
elif args.verbose > 0:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
logging.warning('Skip DEBUG/INFO messages')
dumpdir = Path(args.dumpdir).expanduser() dumpdir = Path(args.dumpdir).expanduser()
# use absolute path # use absolute path

@ -209,11 +209,6 @@ def main():
parser.add_argument("--config", type=str, help="fastspeech2 config file.") parser.add_argument("--config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
parser.add_argument( parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.") "--num-cpu", type=int, default=1, help="number of process.")
@ -248,10 +243,6 @@ def main():
with open(args.config, 'rt') as f: with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f)) config = CfgNode(yaml.safe_load(f))
if args.verbose > 1:
print(vars(args))
print(config)
sentences, speaker_set = get_phn_dur(dur_file) sentences, speaker_set = get_phn_dur(dur_file)
merge_silence(sentences) merge_silence(sentences)

@ -47,30 +47,8 @@ def main():
default=False, default=False,
action="store_true", action="store_true",
help="whether to skip the copy of wav files.") help="whether to skip the copy of wav files.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
args = parser.parse_args()
# set logger args = parser.parse_args()
if args.verbose > 1:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
elif args.verbose > 0:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
logging.warning('Skip DEBUG/INFO messages')
dumpdir = Path(args.dumpdir).expanduser() dumpdir = Path(args.dumpdir).expanduser()
# use absolute path # use absolute path

@ -167,11 +167,6 @@ def main():
required=True, required=True,
help="directory to dump feature files.") help="directory to dump feature files.")
parser.add_argument("--config", type=str, help="vocoder config file.") parser.add_argument("--config", type=str, help="vocoder config file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
parser.add_argument( parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.") "--num-cpu", type=int, default=1, help="number of process.")
parser.add_argument( parser.add_argument(
@ -197,10 +192,6 @@ def main():
with open(args.config, 'rt') as f: with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f)) config = CfgNode(yaml.safe_load(f))
if args.verbose > 1:
print(vars(args))
print(config)
sentences, speaker_set = get_phn_dur(dur_file) sentences, speaker_set = get_phn_dur(dur_file)
merge_silence(sentences) merge_silence(sentences)

@ -50,11 +50,6 @@ def main():
"--tones-dict", type=str, default=None, help="tone vocabulary file.") "--tones-dict", type=str, default=None, help="tone vocabulary file.")
parser.add_argument( parser.add_argument(
"--speaker-dict", type=str, default=None, help="speaker id map file.") "--speaker-dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
parser.add_argument( parser.add_argument(
"--use-relative-path", "--use-relative-path",
@ -63,24 +58,6 @@ def main():
help="whether use relative path in metadata") help="whether use relative path in metadata")
args = parser.parse_args() args = parser.parse_args()
# set logger
if args.verbose > 1:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
elif args.verbose > 0:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
logging.warning('Skip DEBUG/INFO messages')
dumpdir = Path(args.dumpdir).expanduser() dumpdir = Path(args.dumpdir).expanduser()
# use absolute path # use absolute path
dumpdir = dumpdir.resolve() dumpdir = dumpdir.resolve()

@ -195,11 +195,6 @@ def main():
parser.add_argument("--config", type=str, help="fastspeech2 config file.") parser.add_argument("--config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
parser.add_argument( parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.") "--num-cpu", type=int, default=1, help="number of process.")
@ -230,10 +225,6 @@ def main():
with open(args.config, 'rt') as f: with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f)) config = CfgNode(yaml.safe_load(f))
if args.verbose > 1:
print(vars(args))
print(config)
sentences, speaker_set = get_phn_dur(dur_file) sentences, speaker_set = get_phn_dur(dur_file)
merge_silence(sentences) merge_silence(sentences)

@ -184,11 +184,6 @@ def main():
parser.add_argument("--config", type=str, help="fastspeech2 config file.") parser.add_argument("--config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
parser.add_argument( parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.") "--num-cpu", type=int, default=1, help="number of process.")
@ -223,10 +218,6 @@ def main():
with open(args.config, 'rt') as f: with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f)) config = CfgNode(yaml.safe_load(f))
if args.verbose > 1:
print(vars(args))
print(config)
sentences, speaker_set = get_phn_dur(dur_file) sentences, speaker_set = get_phn_dur(dur_file)
merge_silence(sentences) merge_silence(sentences)

@ -51,30 +51,8 @@ def main():
"--phones-dict", type=str, default=None, help="phone vocabulary file.") "--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument( parser.add_argument(
"--speaker-dict", type=str, default=None, help="speaker id map file.") "--speaker-dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
args = parser.parse_args()
# set logger args = parser.parse_args()
if args.verbose > 1:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
elif args.verbose > 0:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
logging.warning('Skip DEBUG/INFO messages')
# check directory existence # check directory existence
dumpdir = Path(args.dumpdir).resolve() dumpdir = Path(args.dumpdir).resolve()

@ -186,11 +186,6 @@ def main():
type=str, type=str,
help="yaml format configuration file.") help="yaml format configuration file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
parser.add_argument( parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.") "--num-cpu", type=int, default=1, help="number of process.")
@ -210,10 +205,6 @@ def main():
_C = Configuration(_C) _C = Configuration(_C)
config = _C.clone() config = _C.clone()
if args.verbose > 1:
print(vars(args))
print(config)
phone_id_map_path = dumpdir / "phone_id_map.txt" phone_id_map_path = dumpdir / "phone_id_map.txt"
speaker_id_map_path = dumpdir / "speaker_id_map.txt" speaker_id_map_path = dumpdir / "speaker_id_map.txt"

@ -16,6 +16,7 @@ import argparse
import logging import logging
from operator import itemgetter from operator import itemgetter
from pathlib import Path from pathlib import Path
from typing import List
import jsonlines import jsonlines
import numpy as np import numpy as np
@ -23,6 +24,50 @@ from sklearn.preprocessing import StandardScaler
from tqdm import tqdm from tqdm import tqdm
from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.utils import str2bool
INITIALS = [
'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh',
'r', 'z', 'c', 's', 'j', 'q', 'x'
]
INITIALS += ['y', 'w', 'sp', 'spl', 'spn', 'sil']
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def insert_after_character(lst, item):
result = [item]
for phone in lst:
result.append(phone)
if phone not in INITIALS:
# finals has tones
assert phone[-1] in "12345"
result.append(item)
return result
def add_blank(phones: List[str],
filed: str="character",
blank_token: str="<pad>"):
if filed == "phone":
"""
add blank after phones
input: ["n", "i3", "h", "ao3", "m", "a5"]
output: ["n", "<pad>", "i3", "<pad>", "h", "<pad>", "ao3", "<pad>", "m", "<pad>", "a5"]
"""
phones = intersperse(phones, blank_token)
elif filed == "character":
"""
add blank after characters
input: ["n", "i3", "h", "ao3"]
output: ["n", "i3", "<pad>", "h", "ao3", "<pad>", "m", "a5"]
"""
phones = insert_after_character(phones, blank_token)
return phones
def main(): def main():
@ -58,29 +103,12 @@ def main():
parser.add_argument( parser.add_argument(
"--speaker-dict", type=str, default=None, help="speaker id map file.") "--speaker-dict", type=str, default=None, help="speaker id map file.")
parser.add_argument( parser.add_argument(
"--verbose", "--add-blank",
type=int, type=str2bool,
default=1, default=True,
help="logging level. higher is more logging. (default=1)") help="whether to add blank between phones")
args = parser.parse_args()
# set logger args = parser.parse_args()
if args.verbose > 1:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
elif args.verbose > 0:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
logging.warning('Skip DEBUG/INFO messages')
dumpdir = Path(args.dumpdir).expanduser() dumpdir = Path(args.dumpdir).expanduser()
# use absolute path # use absolute path
@ -135,13 +163,19 @@ def main():
else: else:
wav_path = wave wav_path = wave
phone_ids = [vocab_phones[p] for p in item['phones']] phones = item['phones']
text_lengths = item['text_lengths']
if args.add_blank:
phones = add_blank(phones, filed="character")
text_lengths = len(phones)
phone_ids = [vocab_phones[p] for p in phones]
spk_id = vocab_speaker[item["speaker"]] spk_id = vocab_speaker[item["speaker"]]
record = { record = {
"utt_id": item['utt_id'], "utt_id": item['utt_id'],
"text": phone_ids, "text": phone_ids,
"text_lengths": item['text_lengths'], "text_lengths": text_lengths,
'feats': str(feats_path), 'feats': str(feats_path),
"feats_lengths": item['feats_lengths'], "feats_lengths": item['feats_lengths'],
"wave": str(wav_path), "wave": str(wav_path),

@ -197,11 +197,6 @@ def main():
parser.add_argument("--config", type=str, help="fastspeech2 config file.") parser.add_argument("--config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
parser.add_argument( parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.") "--num-cpu", type=int, default=1, help="number of process.")
@ -236,10 +231,6 @@ def main():
with open(args.config, 'rt') as f: with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f)) config = CfgNode(yaml.safe_load(f))
if args.verbose > 1:
print(vars(args))
print(config)
sentences, speaker_set = get_phn_dur(dur_file) sentences, speaker_set = get_phn_dur(dur_file)
merge_silence(sentences) merge_silence(sentences)

@ -23,6 +23,7 @@ from yacs.config import CfgNode
from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.models.vits import VITS from paddlespeech.t2s.models.vits import VITS
from paddlespeech.t2s.utils import str2bool
def evaluate(args): def evaluate(args):
@ -55,6 +56,7 @@ def evaluate(args):
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
merge_sentences = False merge_sentences = False
add_blank = args.add_blank
N = 0 N = 0
T = 0 T = 0
@ -62,7 +64,9 @@ def evaluate(args):
with timer() as t: with timer() as t:
if args.lang == 'zh': if args.lang == 'zh':
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences) sentence,
merge_sentences=merge_sentences,
add_blank=add_blank)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
elif args.lang == 'en': elif args.lang == 'en':
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
@ -125,6 +129,12 @@ def parse_args():
help="text to synthesize, a 'utt_id sentence' pair per line.") help="text to synthesize, a 'utt_id sentence' pair per line.")
parser.add_argument("--output_dir", type=str, help="output dir.") parser.add_argument("--output_dir", type=str, help="output dir.")
parser.add_argument(
"--add-blank",
type=str2bool,
default=True,
help="whether to add blank between phones")
args = parser.parse_args() args = parser.parse_args()
return args return args

@ -211,13 +211,18 @@ def train_sp(args, config):
generator_first=config.generator_first, generator_first=config.generator_first,
output_dir=output_dir) output_dir=output_dir)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir) trainer = Trainer(
updater,
stop_trigger=(config.train_max_steps, "iteration"),
out=output_dir)
if dist.get_rank() == 0: if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch")) trainer.extend(
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration")) evaluator, trigger=(config.eval_interval_steps, 'iteration'))
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
trainer.extend( trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch')) Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
print("Trainer Done!") print("Trainer Done!")
trainer.run() trainer.run()

@ -143,8 +143,6 @@ if __name__ == "__main__":
nargs=argparse.REMAINDER, nargs=argparse.REMAINDER,
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs" help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
) )
parser.add_argument(
"-v", "--verbose", action="store_true", help="print msg")
config = get_cfg_defaults() config = get_cfg_defaults()
args = parser.parse_args() args = parser.parse_args()
@ -153,8 +151,5 @@ if __name__ == "__main__":
if args.opts: if args.opts:
config.merge_from_list(args.opts) config.merge_from_list(args.opts)
config.freeze() config.freeze()
if args.verbose:
print(config.data)
print(args)
create_dataset(config.data, args.input, args.output) create_dataset(config.data, args.input, args.output)

@ -72,8 +72,6 @@ if __name__ == "__main__":
nargs=argparse.REMAINDER, nargs=argparse.REMAINDER,
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs" help="options to overwrite --config file and the default config, passing in KEY VALUE pairs"
) )
parser.add_argument(
"-v", "--verbose", action="store_true", help="print msg")
args = parser.parse_args() args = parser.parse_args()
if args.config: if args.config:

@ -29,6 +29,29 @@ from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon
from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
INITIALS = [
'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh',
'r', 'z', 'c', 's', 'j', 'q', 'x'
]
INITIALS += ['y', 'w', 'sp', 'spl', 'spn', 'sil']
def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result
def insert_after_character(lst, item):
result = [item]
for phone in lst:
result.append(phone)
if phone not in INITIALS:
# finals has tones
# assert phone[-1] in "12345"
result.append(item)
return result
class Frontend(): class Frontend():
def __init__(self, def __init__(self,
@ -280,12 +303,15 @@ class Frontend():
print("----------------------------") print("----------------------------")
return phonemes return phonemes
def get_input_ids(self, def get_input_ids(
sentence: str, self,
merge_sentences: bool=True, sentence: str,
get_tone_ids: bool=False, merge_sentences: bool=True,
robot: bool=False, get_tone_ids: bool=False,
print_info: bool=False) -> Dict[str, List[paddle.Tensor]]: robot: bool=False,
print_info: bool=False,
add_blank: bool=False,
blank_token: str="<pad>") -> Dict[str, List[paddle.Tensor]]:
phonemes = self.get_phonemes( phonemes = self.get_phonemes(
sentence, sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
@ -299,6 +325,10 @@ class Frontend():
for part_phonemes in phonemes: for part_phonemes in phonemes:
phones, tones = self._get_phone_tone( phones, tones = self._get_phone_tone(
part_phonemes, get_tone_ids=get_tone_ids) part_phonemes, get_tone_ids=get_tone_ids)
if add_blank:
phones = insert_after_character(phones, blank_token)
if tones: if tones:
tone_ids = self._t2id(tones) tone_ids = self._t2id(tones)
tone_ids = paddle.to_tensor(tone_ids) tone_ids = paddle.to_tensor(tone_ids)

@ -227,11 +227,7 @@ class VITS(nn.Layer):
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
forward_generator (bool): Whether to forward generator. forward_generator (bool): Whether to forward generator.
Returns: Returns:
Dict[str, Any]:
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
- weight (Tensor): Weight tensor to summarize losses.
- optim_idx (int): Optimizer index (0 for G and 1 for D).
""" """
if forward_generator: if forward_generator:
return self._forward_generator( return self._forward_generator(

Loading…
Cancel
Save