You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
480 lines
15 KiB
480 lines
15 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import argparse
|
|
import os
|
|
from pathlib import Path
|
|
from typing import List
|
|
|
|
import librosa
|
|
import numpy as np
|
|
import paddle
|
|
import pypinyin
|
|
import soundfile as sf
|
|
import yaml
|
|
from pypinyin_dict.phrase_pinyin_data import large_pinyin
|
|
from yacs.config import CfgNode
|
|
|
|
from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
|
|
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
|
|
from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
|
|
from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs
|
|
from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor
|
|
from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy
|
|
from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
|
|
from paddlespeech.t2s.exps.syn_utils import get_am_inference
|
|
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
|
|
from paddlespeech.t2s.exps.syn_utils import norm
|
|
from paddlespeech.t2s.utils import str2bool
|
|
large_pinyin.load()
|
|
|
|
|
|
def _p2id(phonemes: List[str]) -> np.ndarray:
|
|
# replace unk phone with sp
|
|
phonemes = [phn if phn in vocab_phones else "sp" for phn in phonemes]
|
|
phone_ids = [vocab_phones[item] for item in phonemes]
|
|
return np.array(phone_ids, np.int64)
|
|
|
|
|
|
def prep_feats_with_dur(wav_path: str,
|
|
old_str: str='',
|
|
new_str: str='',
|
|
source_lang: str='en',
|
|
target_lang: str='en',
|
|
duration_adjust: bool=True,
|
|
fs: int=24000,
|
|
n_shift: int=300):
|
|
'''
|
|
Returns:
|
|
np.ndarray: new wav, replace the part to be edited in original wav with 0
|
|
List[str]: new phones
|
|
List[float]: mfa start of new wav
|
|
List[float]: mfa end of new wav
|
|
List[int]: masked mel boundary of original wav
|
|
List[int]: masked mel boundary of new wav
|
|
'''
|
|
wav_org, _ = librosa.load(wav_path, sr=fs)
|
|
phns_spans_outs = get_phns_spans(
|
|
wav_path=wav_path,
|
|
old_str=old_str,
|
|
new_str=new_str,
|
|
source_lang=source_lang,
|
|
target_lang=target_lang,
|
|
fs=fs,
|
|
n_shift=n_shift)
|
|
|
|
mfa_start = phns_spans_outs['mfa_start']
|
|
mfa_end = phns_spans_outs['mfa_end']
|
|
old_phns = phns_spans_outs['old_phns']
|
|
new_phns = phns_spans_outs['new_phns']
|
|
span_to_repl = phns_spans_outs['span_to_repl']
|
|
span_to_add = phns_spans_outs['span_to_add']
|
|
|
|
# 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替
|
|
if target_lang in {'en', 'zh'}:
|
|
old_durs = eval_durs(old_phns, target_lang=source_lang)
|
|
else:
|
|
assert target_lang in {'en', 'zh'}, \
|
|
"calculate duration_predict is not support for this language..."
|
|
|
|
orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
|
|
|
|
if duration_adjust:
|
|
d_factor = get_dur_adj_factor(
|
|
orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
|
|
d_factor = d_factor * 1.25
|
|
else:
|
|
d_factor = 1
|
|
|
|
if target_lang in {'en', 'zh'}:
|
|
new_durs = eval_durs(new_phns, target_lang=target_lang)
|
|
else:
|
|
assert target_lang == "zh" or target_lang == "en", \
|
|
"calculate duration_predict is not support for this language..."
|
|
|
|
# duration 要是整数
|
|
new_durs_adjusted = [int(np.ceil(d_factor * i)) for i in new_durs]
|
|
|
|
new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
|
|
old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
|
|
dur_offset = new_span_dur_sum - old_span_dur_sum
|
|
new_mfa_start = mfa_start[:span_to_repl[0]]
|
|
new_mfa_end = mfa_end[:span_to_repl[0]]
|
|
|
|
for dur in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
|
|
if len(new_mfa_end) == 0:
|
|
new_mfa_start.append(0)
|
|
new_mfa_end.append(dur)
|
|
else:
|
|
new_mfa_start.append(new_mfa_end[-1])
|
|
new_mfa_end.append(new_mfa_end[-1] + dur)
|
|
|
|
new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
|
|
new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]
|
|
|
|
# 3. get new wav
|
|
# 在原始句子后拼接
|
|
if span_to_repl[0] >= len(mfa_start):
|
|
wav_left_idx = len(wav_org)
|
|
wav_right_idx = wav_left_idx
|
|
# 在原始句子中间替换
|
|
else:
|
|
wav_left_idx = int(np.floor(mfa_start[span_to_repl[0]] * n_shift))
|
|
wav_right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * n_shift))
|
|
blank_wav = np.zeros(
|
|
(int(np.ceil(new_span_dur_sum * n_shift)), ), dtype=wav_org.dtype)
|
|
# 原始音频,需要编辑的部分替换成空音频,空音频的时间由 fs2 的 duration_predictor 决定
|
|
new_wav = np.concatenate(
|
|
[wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]])
|
|
|
|
# 4. get old and new mel span to be mask
|
|
old_span_bdy = get_span_bdy(
|
|
mfa_start=mfa_start, mfa_end=mfa_end, span_to_repl=span_to_repl)
|
|
|
|
new_span_bdy = get_span_bdy(
|
|
mfa_start=new_mfa_start, mfa_end=new_mfa_end, span_to_repl=span_to_add)
|
|
|
|
# old_span_bdy, new_span_bdy 是帧级别的范围
|
|
outs = {}
|
|
outs['new_wav'] = new_wav
|
|
outs['new_phns'] = new_phns
|
|
outs['new_mfa_start'] = new_mfa_start
|
|
outs['new_mfa_end'] = new_mfa_end
|
|
outs['old_span_bdy'] = old_span_bdy
|
|
outs['new_span_bdy'] = new_span_bdy
|
|
return outs
|
|
|
|
|
|
def prep_feats(wav_path: str,
|
|
old_str: str='',
|
|
new_str: str='',
|
|
source_lang: str='en',
|
|
target_lang: str='en',
|
|
duration_adjust: bool=True,
|
|
fs: int=24000,
|
|
n_shift: int=300):
|
|
|
|
with_dur_outs = prep_feats_with_dur(
|
|
wav_path=wav_path,
|
|
old_str=old_str,
|
|
new_str=new_str,
|
|
source_lang=source_lang,
|
|
target_lang=target_lang,
|
|
duration_adjust=duration_adjust,
|
|
fs=fs,
|
|
n_shift=n_shift)
|
|
|
|
wav_name = os.path.basename(wav_path)
|
|
utt_id = wav_name.split('.')[0]
|
|
|
|
wav = with_dur_outs['new_wav']
|
|
phns = with_dur_outs['new_phns']
|
|
mfa_start = with_dur_outs['new_mfa_start']
|
|
mfa_end = with_dur_outs['new_mfa_end']
|
|
old_span_bdy = with_dur_outs['old_span_bdy']
|
|
new_span_bdy = with_dur_outs['new_span_bdy']
|
|
span_bdy = np.array(new_span_bdy)
|
|
|
|
mel = mel_extractor.get_log_mel_fbank(wav)
|
|
erniesat_mean, erniesat_std = np.load(erniesat_stat)
|
|
normed_mel = norm(mel, erniesat_mean, erniesat_std)
|
|
tmp_name = get_tmp_name(text=old_str)
|
|
tmpbase = './tmp_dir/' + tmp_name
|
|
tmpbase = Path(tmpbase)
|
|
tmpbase.mkdir(parents=True, exist_ok=True)
|
|
|
|
mel_path = tmpbase / 'mel.npy'
|
|
np.save(mel_path, normed_mel)
|
|
durations = [e - s for e, s in zip(mfa_end, mfa_start)]
|
|
text = _p2id(phns)
|
|
|
|
datum = {
|
|
"utt_id": utt_id,
|
|
"spk_id": 0,
|
|
"text": text,
|
|
"text_lengths": len(text),
|
|
"speech_lengths": len(normed_mel),
|
|
"durations": durations,
|
|
"speech": np.load(mel_path),
|
|
"align_start": mfa_start,
|
|
"align_end": mfa_end,
|
|
"span_bdy": span_bdy
|
|
}
|
|
|
|
batch = collate_fn([datum])
|
|
outs = dict()
|
|
outs['batch'] = batch
|
|
outs['old_span_bdy'] = old_span_bdy
|
|
outs['new_span_bdy'] = new_span_bdy
|
|
return outs
|
|
|
|
|
|
def get_mlm_output(wav_path: str,
|
|
old_str: str='',
|
|
new_str: str='',
|
|
source_lang: str='en',
|
|
target_lang: str='en',
|
|
duration_adjust: bool=True,
|
|
fs: int=24000,
|
|
n_shift: int=300):
|
|
|
|
prep_feats_outs = prep_feats(
|
|
wav_path=wav_path,
|
|
old_str=old_str,
|
|
new_str=new_str,
|
|
source_lang=source_lang,
|
|
target_lang=target_lang,
|
|
duration_adjust=duration_adjust,
|
|
fs=fs,
|
|
n_shift=n_shift)
|
|
|
|
batch = prep_feats_outs['batch']
|
|
new_span_bdy = prep_feats_outs['new_span_bdy']
|
|
old_span_bdy = prep_feats_outs['old_span_bdy']
|
|
|
|
out_mels = erniesat_inference(
|
|
speech=batch['speech'],
|
|
text=batch['text'],
|
|
masked_pos=batch['masked_pos'],
|
|
speech_mask=batch['speech_mask'],
|
|
text_mask=batch['text_mask'],
|
|
speech_seg_pos=batch['speech_seg_pos'],
|
|
text_seg_pos=batch['text_seg_pos'],
|
|
span_bdy=new_span_bdy)
|
|
|
|
# 拼接音频
|
|
output_feat = paddle.concat(x=out_mels, axis=0)
|
|
wav_org, _ = librosa.load(wav_path, sr=fs)
|
|
outs = dict()
|
|
outs['wav_org'] = wav_org
|
|
outs['output_feat'] = output_feat
|
|
outs['old_span_bdy'] = old_span_bdy
|
|
outs['new_span_bdy'] = new_span_bdy
|
|
|
|
return outs
|
|
|
|
|
|
def get_wav(wav_path: str,
|
|
source_lang: str='en',
|
|
target_lang: str='en',
|
|
old_str: str='',
|
|
new_str: str='',
|
|
duration_adjust: bool=True,
|
|
fs: int=24000,
|
|
n_shift: int=300,
|
|
task_name: str='synthesize'):
|
|
|
|
outs = get_mlm_output(
|
|
wav_path=wav_path,
|
|
old_str=old_str,
|
|
new_str=new_str,
|
|
source_lang=source_lang,
|
|
target_lang=target_lang,
|
|
duration_adjust=duration_adjust,
|
|
fs=fs,
|
|
n_shift=n_shift)
|
|
|
|
wav_org = outs['wav_org']
|
|
output_feat = outs['output_feat']
|
|
old_span_bdy = outs['old_span_bdy']
|
|
new_span_bdy = outs['new_span_bdy']
|
|
|
|
masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
|
|
|
|
with paddle.no_grad():
|
|
alt_wav = voc_inference(masked_feat)
|
|
alt_wav = np.squeeze(alt_wav)
|
|
|
|
old_time_bdy = [n_shift * x for x in old_span_bdy]
|
|
if task_name == 'edit':
|
|
wav_replaced = np.concatenate(
|
|
[wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])
|
|
else:
|
|
wav_replaced = alt_wav
|
|
wav_dict = {"origin": wav_org, "output": wav_replaced}
|
|
return wav_dict
|
|
|
|
|
|
def parse_args():
|
|
# parse args and config
|
|
parser = argparse.ArgumentParser(
|
|
description="Synthesize with acoustic model & vocoder")
|
|
# ernie sat
|
|
|
|
parser.add_argument(
|
|
'--erniesat_config',
|
|
type=str,
|
|
default=None,
|
|
help='Config of acoustic model.')
|
|
parser.add_argument(
|
|
'--erniesat_ckpt',
|
|
type=str,
|
|
default=None,
|
|
help='Checkpoint file of acoustic model.')
|
|
parser.add_argument(
|
|
"--erniesat_stat",
|
|
type=str,
|
|
default=None,
|
|
help="mean and standard deviation used to normalize spectrogram when training acoustic model."
|
|
)
|
|
parser.add_argument(
|
|
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
|
|
# vocoder
|
|
parser.add_argument(
|
|
'--voc',
|
|
type=str,
|
|
default='pwgan_csmsc',
|
|
choices=[
|
|
'pwgan_aishell3',
|
|
'pwgan_vctk',
|
|
'hifigan_aishell3',
|
|
'hifigan_vctk',
|
|
],
|
|
help='Choose vocoder type of tts task.')
|
|
parser.add_argument(
|
|
'--voc_config', type=str, default=None, help='Config of voc.')
|
|
parser.add_argument(
|
|
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
|
|
parser.add_argument(
|
|
"--voc_stat",
|
|
type=str,
|
|
default=None,
|
|
help="mean and standard deviation used to normalize spectrogram when training voc."
|
|
)
|
|
# other
|
|
parser.add_argument(
|
|
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
|
|
|
|
# ernie sat related
|
|
parser.add_argument(
|
|
"--task_name",
|
|
type=str,
|
|
choices=['edit', 'synthesize'],
|
|
help="task name.")
|
|
parser.add_argument("--wav_path", type=str, help="path of old wav")
|
|
parser.add_argument("--old_str", type=str, help="old string")
|
|
parser.add_argument("--new_str", type=str, help="new string")
|
|
parser.add_argument(
|
|
"--source_lang", type=str, default="en", help="source language")
|
|
parser.add_argument(
|
|
"--target_lang", type=str, default="en", help="target language")
|
|
parser.add_argument(
|
|
"--duration_adjust",
|
|
type=str2bool,
|
|
default=True,
|
|
help="whether to adjust duration.")
|
|
parser.add_argument("--output_name", type=str, default="output.wav")
|
|
|
|
args = parser.parse_args()
|
|
return args
|
|
|
|
|
|
if __name__ == '__main__':
|
|
args = parse_args()
|
|
|
|
if args.ngpu == 0:
|
|
paddle.set_device("cpu")
|
|
elif args.ngpu > 0:
|
|
paddle.set_device("gpu")
|
|
else:
|
|
print("ngpu should >= 0 !")
|
|
|
|
# evaluate(args)
|
|
with open(args.erniesat_config) as f:
|
|
erniesat_config = CfgNode(yaml.safe_load(f))
|
|
old_str = args.old_str
|
|
new_str = args.new_str
|
|
|
|
# convert Chinese characters to pinyin
|
|
if args.source_lang == 'zh':
|
|
old_str = pypinyin.lazy_pinyin(
|
|
old_str,
|
|
neutral_tone_with_five=True,
|
|
style=pypinyin.Style.TONE3,
|
|
tone_sandhi=True)
|
|
old_str = ' '.join(old_str)
|
|
if args.target_lang == 'zh':
|
|
new_str = pypinyin.lazy_pinyin(
|
|
new_str,
|
|
neutral_tone_with_five=True,
|
|
style=pypinyin.Style.TONE3,
|
|
tone_sandhi=True)
|
|
new_str = ' '.join(new_str)
|
|
|
|
if args.task_name == 'edit':
|
|
new_str = new_str
|
|
elif args.task_name == 'synthesize':
|
|
new_str = old_str + ' ' + new_str
|
|
else:
|
|
new_str = old_str + ' ' + new_str
|
|
|
|
# Extractor
|
|
mel_extractor = LogMelFBank(
|
|
sr=erniesat_config.fs,
|
|
n_fft=erniesat_config.n_fft,
|
|
hop_length=erniesat_config.n_shift,
|
|
win_length=erniesat_config.win_length,
|
|
window=erniesat_config.window,
|
|
n_mels=erniesat_config.n_mels,
|
|
fmin=erniesat_config.fmin,
|
|
fmax=erniesat_config.fmax)
|
|
|
|
collate_fn = build_erniesat_collate_fn(
|
|
mlm_prob=erniesat_config.mlm_prob,
|
|
mean_phn_span=erniesat_config.mean_phn_span,
|
|
seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
|
|
text_masking=False)
|
|
|
|
vocab_phones = {}
|
|
|
|
with open(args.phones_dict, 'rt') as f:
|
|
phn_id = [line.strip().split() for line in f.readlines()]
|
|
for phn, id in phn_id:
|
|
vocab_phones[phn] = int(id)
|
|
|
|
# ernie sat model
|
|
erniesat_inference = get_am_inference(
|
|
am='erniesat_dataset',
|
|
am_config=erniesat_config,
|
|
am_ckpt=args.erniesat_ckpt,
|
|
am_stat=args.erniesat_stat,
|
|
phones_dict=args.phones_dict)
|
|
|
|
with open(args.voc_config) as f:
|
|
voc_config = CfgNode(yaml.safe_load(f))
|
|
|
|
# vocoder
|
|
voc_inference = get_voc_inference(
|
|
voc=args.voc,
|
|
voc_config=voc_config,
|
|
voc_ckpt=args.voc_ckpt,
|
|
voc_stat=args.voc_stat)
|
|
|
|
erniesat_stat = args.erniesat_stat
|
|
|
|
wav_dict = get_wav(
|
|
wav_path=args.wav_path,
|
|
source_lang=args.source_lang,
|
|
target_lang=args.target_lang,
|
|
old_str=old_str,
|
|
new_str=new_str,
|
|
duration_adjust=args.duration_adjust,
|
|
fs=erniesat_config.fs,
|
|
n_shift=erniesat_config.n_shift,
|
|
task_name=args.task_name)
|
|
|
|
sf.write(
|
|
args.output_name, wav_dict['output'], samplerate=erniesat_config.fs)
|
|
print(
|
|
f"\033[1;32;m Generated audio saved into {args.output_name} ! \033[0m")
|