You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py

347 lines
11 KiB

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import List
from typing import Optional

import librosa
import numpy as np
import soundfile as sf

from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs
from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor
from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy
from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import norm


def _p2id(phonemes: List[str]) -> np.ndarray:
    '''Convert a phone sequence to an array of vocabulary ids.

    Phones that are not present in the module-level ``vocab_phones`` mapping
    are replaced by the ``"sp"`` phone before lookup.

    Args:
        phonemes (List[str]): phone symbols to convert.

    Returns:
        np.ndarray: int64 array with one id per input phone.

    Raises:
        KeyError: if an unknown phone is met and ``"sp"`` itself is missing
            from ``vocab_phones``.
    '''
    # This is a plain module-level function, so it takes no ``self``;
    # the stray parameter made the one-argument call in this file fail.
    # replace unk phone with sp
    known_phones = [
        phn if phn in vocab_phones else "sp" for phn in phonemes
    ]
    phone_ids = [vocab_phones[phn] for phn in known_phones]
    return np.array(phone_ids, np.int64)
def prep_feats_with_dur(wav_path: str,
                        old_str: str='',
                        new_str: str='',
                        source_lang: str='en',
                        target_lang: str='en',
                        duration_adjust: bool=True,
                        fs: int=24000,
                        n_shift: int=300):
    '''Blank out the edited region of a recording and re-time its alignment.

    The original phones/alignment and the span to edit come from
    ``get_phns_spans``; phone durations are predicted with ``eval_durs``.
    The audio covering the replaced span is swapped for zeros whose length
    is the (optionally rescaled) predicted duration of the new span.

    Side effect: the resulting waveform is also written to ``new_wav.wav``
    in the current working directory.

    Args:
        wav_path (str): path of the original recording.
        old_str (str): transcript of the original recording.
        new_str (str): transcript after editing.
        source_lang (str): language of ``old_str``; 'en' or 'zh'.
        target_lang (str): language of ``new_str``; 'en' or 'zh'.
        duration_adjust (bool): rescale predicted durations by the factor
            returned by ``get_dur_adj_factor`` (times 1.25).
        fs (int): sample rate the recording is loaded at.
        n_shift (int): samples per alignment step; alignment values are
            multiplied by it to obtain sample indices.

    Returns:
        dict: with keys
            'new_wav' (np.ndarray): original wav with the edited part
                replaced by zeros.
            'new_phns': phones of the edited utterance.
            'new_mfa_start' (list): alignment starts for the new wav.
            'new_mfa_end' (list): alignment ends for the new wav.
            'old_span_bdy': masked span boundary in the original wav.
            'new_span_bdy': masked span boundary in the new wav.

    Raises:
        AssertionError: if ``source_lang`` or ``target_lang`` is not
            'en' or 'zh'.
    '''
    # Durations are predicted for BOTH languages, so both are validated.
    # Raised explicitly (not via ``assert``) so the check survives
    # ``python -O``; the exception type is kept for compatibility.
    for lang in (source_lang, target_lang):
        if lang not in {'en', 'zh'}:
            raise AssertionError(
                "calculate duration_predict is not support for this language..."
            )

    wav_org, _ = librosa.load(wav_path, sr=fs)
    phns_spans_outs = get_phns_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang,
        fs=fs,
        n_shift=n_shift)

    mfa_start = phns_spans_outs["mfa_start"]
    mfa_end = phns_spans_outs["mfa_end"]
    old_phns = phns_spans_outs["old_phns"]
    new_phns = phns_spans_outs["new_phns"]
    span_to_repl = phns_spans_outs["span_to_repl"]
    span_to_add = phns_spans_outs["span_to_add"]
    repl_lo, repl_hi = span_to_repl[0], span_to_repl[1]
    add_lo, add_hi = span_to_add[0], span_to_add[1]

    # Chinese phones are not guaranteed to be in the fastspeech2
    # dictionary; "sp" is used in their place.
    old_durs = eval_durs(old_phns, target_lang=source_lang)
    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]

    if duration_adjust:
        d_factor = get_dur_adj_factor(
            orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
        # NOTE(review): 1.25 is an unexplained constant kept from the
        # original code -- presumably an empirical slow-down; confirm.
        d_factor = d_factor * 1.25
    else:
        d_factor = 1

    new_durs = eval_durs(new_phns, target_lang=target_lang)

    # durations have to be integers
    new_durs_adjusted = [int(np.ceil(d_factor * i)) for i in new_durs]
    added_durs = new_durs_adjusted[add_lo:add_hi]

    new_span_dur_sum = sum(added_durs)
    old_span_dur_sum = sum(orig_old_durs[repl_lo:repl_hi])
    dur_offset = new_span_dur_sum - old_span_dur_sum

    # Alignment before the edit is kept as is (copied so the input
    # alignment is never mutated by the appends below).
    new_mfa_start = list(mfa_start[:repl_lo])
    new_mfa_end = list(mfa_end[:repl_lo])
    # Phones of the new span are laid out back to back after it.
    for dur in added_durs:
        start = new_mfa_end[-1] if len(new_mfa_end) > 0 else 0
        new_mfa_start.append(start)
        new_mfa_end.append(start + dur)
    # Alignment after the edit is shifted by the change in span length.
    new_mfa_start += [i + dur_offset for i in mfa_start[repl_hi:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[repl_hi:]]

    # 3. get new wav
    if repl_lo >= len(mfa_start):
        # append after the original sentence
        wav_left_idx = len(wav_org)
        wav_right_idx = wav_left_idx
    else:
        # replace in the middle of the original sentence
        wav_left_idx = int(np.floor(mfa_start[repl_lo] * n_shift))
        wav_right_idx = int(np.ceil(mfa_end[repl_hi - 1] * n_shift))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * n_shift)), ), dtype=wav_org.dtype)
    # Original audio with the part to edit replaced by silence; the length
    # of the silence is decided by the fs2 duration predictor.
    new_wav = np.concatenate(
        [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]])

    # Dump the masked audio (presumably to check the masking by ear).
    sf.write(str("new_wav.wav"), new_wav, samplerate=fs)

    # 4. get old and new mel span to be mask
    old_span_bdy = get_span_bdy(
        mfa_start=mfa_start, mfa_end=mfa_end, span_to_repl=span_to_repl)
    new_span_bdy = get_span_bdy(
        mfa_start=new_mfa_start,
        mfa_end=new_mfa_end,
        span_to_repl=span_to_add)

    # old_span_bdy and new_span_bdy are frame-level ranges
    outs = {
        'new_wav': new_wav,
        'new_phns': new_phns,
        'new_mfa_start': new_mfa_start,
        'new_mfa_end': new_mfa_end,
        'old_span_bdy': old_span_bdy,
        'new_span_bdy': new_span_bdy,
    }
    return outs
def prep_feats(wav_path: str,
               old_str: str='',
               new_str: str='',
               source_lang: str='en',
               target_lang: str='en',
               duration_adjust: bool=True,
               fs: int=24000,
               n_shift: int=300):
    '''Build a one-utterance model batch for an edited recording.

    Runs ``prep_feats_with_dur``, extracts and normalises the log-mel
    features of the masked waveform, saves them to
    ``./tmp_dir/<tmp_name>/mel.npy`` and collates a single datum.

    Relies on module-level objects created by the script entry point:
    ``mel_extractor``, ``erniesat_stat``, ``collate_fn`` and (through
    ``_p2id``) ``vocab_phones``.

    Args:
        wav_path (str): path of the original recording.
        old_str (str): transcript of the original recording.
        new_str (str): transcript after editing.
        source_lang (str): language of ``old_str``.
        target_lang (str): language of ``new_str``.
        duration_adjust (bool): forwarded to ``prep_feats_with_dur``.
        fs (int): sample rate.
        n_shift (int): samples per alignment step.

    Returns:
        tuple: ``(batch, old_span_bdy, new_span_bdy)`` where ``batch`` is
        the output of ``collate_fn`` on the single datum.
    '''
    outs = prep_feats_with_dur(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang,
        duration_adjust=duration_adjust,
        fs=fs,
        n_shift=n_shift)

    # utterance id = file name up to its first dot
    utt_id = os.path.basename(wav_path).split('.')[0]

    wav = outs['new_wav']
    phns = outs['new_phns']
    mfa_start = outs['new_mfa_start']
    mfa_end = outs['new_mfa_end']
    old_span_bdy = outs['old_span_bdy']
    new_span_bdy = outs['new_span_bdy']
    span_bdy = np.array(new_span_bdy)

    text = _p2id(phns)

    mel = mel_extractor.get_log_mel_fbank(wav)
    erniesat_mean, erniesat_std = np.load(erniesat_stat)
    normed_mel = norm(mel, erniesat_mean, erniesat_std)

    tmp_name = get_tmp_name(text=old_str)
    tmpbase = Path('./tmp_dir/' + tmp_name)
    tmpbase.mkdir(parents=True, exist_ok=True)
    print("tmp_name in synthesize_e2e:", tmp_name)

    mel_path = tmpbase / 'mel.npy'
    print("mel_path:", mel_path)
    # Save the normalised features computed above (the previous code
    # referenced an undefined name here).
    np.save(mel_path, normed_mel)

    durations = [e - s for e, s in zip(mfa_end, mfa_start)]
    datum = {
        "utt_id": utt_id,
        "spk_id": 0,
        "text": text,
        "text_lengths": len(text),
        # Derived from the features instead of a hard-coded 115;
        # assumes the first axis of the mel is time -- TODO confirm.
        "speech_lengths": len(normed_mel),
        "durations": durations,
        # NOTE(review): the path of the saved mel is passed, not the array;
        # verify this is what collate_fn expects.
        "speech": mel_path,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "span_bdy": span_bdy,
    }

    batch = collate_fn([datum])
    print("batch:", batch)

    return batch, old_span_bdy, new_span_bdy
def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      old_str: str='',
                      new_str: str='',
                      source_lang: str='en',
                      target_lang: str='en',
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      fs: int=24000,
                      n_shift: int=300,
                      token_list: Optional[List[str]]=None):
    '''Run the masked-LM model on an edited utterance.

    Args:
        mlm_model (nn.Layer): model exposing ``inference(...)``.
        collate_fn: callable applied to the batch from ``prep_feats``;
            element ``[1]`` of its result is used as the feature dict.
        wav_path (str): path of the original recording.
        old_str (str): transcript of the original recording.
        new_str (str): transcript after editing.
        source_lang (str): language of ``old_str``.
        target_lang (str): language of ``new_str``.
        use_teacher_forcing (bool): forwarded to ``mlm_model.inference``.
        duration_adjust (bool): forwarded to ``prep_feats``.
        fs (int): sample rate.
        n_shift (int): hop length in samples.
        token_list (Optional[List[str]]): accepted for backward
            compatibility; currently unused.

    Returns:
        tuple: ``(wav_org, output_feat, old_span_bdy, new_span_bdy, fs,
        n_shift)``; ``output_feat`` is the model output concatenated
        along axis 0.
    '''
    # ``token_list`` is deliberately not forwarded: ``prep_feats`` has no
    # such parameter and the call would raise TypeError.
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        fs=fs,
        n_shift=n_shift)

    # NOTE(review): prep_feats already returns a collated batch; confirm
    # that collating it a second time here is intended.
    feats = collate_fn(batch)[1]

    # inference() takes no 'text_masked_pos'; drop it when present
    feats.pop('text_masked_pos', None)

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the predicted pieces
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    # The hop length is the ``n_shift`` parameter (the previous code
    # returned an undefined ``hop_length``).
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, n_shift
if __name__ == '__main__':
    # Manual smoke run: builds the module-level objects that prep_feats
    # reads (mel_extractor, erniesat_stat, collate_fn, vocab_phones) and
    # prepares one batch for a sample utterance.
    fs = 24000
    n_shift = 300
    wav_path = "exp/p243_313.wav"
    old_str = "For that reason cover should not be given."
    # for edit
    # new_str = "for that reason cover is impossible to be given."
    # for synthesize
    append_str = "do you love me i love you so much"
    # NOTE(review): no separator is inserted between the two strings.
    new_str = old_str + append_str

    # To inspect only the alignment/duration stage, call
    # prep_feats_with_dur(wav_path=..., old_str=..., new_str=..., fs=...,
    # n_shift=...) and print the entries of the returned dict.

    # TODO: the paths below are absolute and machine-specific.
    erniesat_config_path = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/local/default.yaml"
    with open(erniesat_config_path, encoding='utf-8') as f:
        erniesat_config = CfgNode(yaml.safe_load(f))

    erniesat_stat = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/train/speech_stats.npy"

    # Extractor
    mel_extractor = LogMelFBank(
        sr=erniesat_config.fs,
        n_fft=erniesat_config.n_fft,
        hop_length=erniesat_config.n_shift,
        win_length=erniesat_config.win_length,
        window=erniesat_config.window,
        n_mels=erniesat_config.n_mels,
        fmin=erniesat_config.fmin,
        fmax=erniesat_config.fmax)

    collate_fn = build_erniesat_collate_fn(
        mlm_prob=erniesat_config.mlm_prob,
        mean_phn_span=erniesat_config.mean_phn_span,
        seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
        text_masking=False)

    phones_dict = '/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt'
    # Each non-empty line of the map is "<phone> <id>".
    vocab_phones = {}
    with open(phones_dict, 'rt', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # tolerate blank lines (e.g. a trailing newline)
                continue
            phn, phn_idx = fields
            vocab_phones[phn] = int(phn_idx)

    prep_feats(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        fs=fs,
        n_shift=n_shift)