# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import List

import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import nn
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs
from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor
from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy
from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import norm

def _p2id(phonemes: List[str]) -> np.ndarray:
    """Convert a phoneme sequence into an int64 array of vocabulary ids.

    Phonemes missing from the module-level ``vocab_phones`` mapping are
    replaced with the silence token "sp" before lookup.

    Note: the original signature took a spurious ``self`` parameter even
    though this is a module-level function called as ``_p2id(phns)``,
    which raised a TypeError (``phonemes`` missing). ``self`` is removed.

    Args:
        phonemes: phoneme symbols to convert.

    Returns:
        np.ndarray: int64 array of phone ids, one per input phoneme.
    """
    # replace unk phone with sp
    phonemes = [phn if phn in vocab_phones else "sp" for phn in phonemes]
    phone_ids = [vocab_phones[item] for item in phonemes]
    return np.array(phone_ids, np.int64)
def prep_feats_with_dur(wav_path: str,
                        old_str: str='',
                        new_str: str='',
                        source_lang: str='en',
                        target_lang: str='en',
                        duration_adjust: bool=True,
                        fs: int=24000,
                        n_shift: int=300):
    '''
    Build the edited ("new") waveform and its phone-level alignment.

    The span of the original wav that is to be edited is replaced by
    silence whose length is predicted by the duration model.

    Returns:
        dict with keys:
            new_wav (np.ndarray): new wav, replace the part to be edited in original wav with 0
            new_phns (List[str]): new phones
            new_mfa_start (List[float]): mfa start of new wav
            new_mfa_end (List[float]): mfa end of new wav
            old_span_bdy (List[int]): masked mel boundary of original wav
            new_span_bdy (List[int]): masked mel boundary of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)
    phns_spans_outs = get_phns_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang,
        fs=fs,
        n_shift=n_shift)

    mfa_start = phns_spans_outs["mfa_start"]
    mfa_end = phns_spans_outs["mfa_end"]
    old_phns = phns_spans_outs["old_phns"]
    new_phns = phns_spans_outs["new_phns"]
    span_to_repl = phns_spans_outs["span_to_repl"]
    span_to_add = phns_spans_outs["span_to_add"]

    # 1. predict durations of the *old* phones.
    # Chinese phns are not all in the fastspeech2 dict; they are replaced
    # with "sp" downstream.
    # BUGFIX: the old durations are predicted with
    # ``target_lang=source_lang``, so the guard must check ``source_lang``
    # (the original checked ``target_lang``, which breaks cross-lingual
    # editing when source and target languages differ).
    if source_lang in {'en', 'zh'}:
        old_durs = eval_durs(old_phns, target_lang=source_lang)
    else:
        assert source_lang in {'en', 'zh'}, \
            "calculate duration_predict is not support for this language..."

    # ground-truth durations taken from the MFA alignment
    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]

    if duration_adjust:
        d_factor = get_dur_adj_factor(
            orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
        # empirical slow-down factor so the inserted span is not rushed
        d_factor = d_factor * 1.25
    else:
        d_factor = 1

    # 2. predict durations of the *new* phones
    if target_lang in {'en', 'zh'}:
        new_durs = eval_durs(new_phns, target_lang=target_lang)
    else:
        assert target_lang == "zh" or target_lang == "en", \
            "calculate duration_predict is not support for this language..."

    # durations must be integer frame counts
    new_durs_adjusted = [int(np.ceil(d_factor * i)) for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]

    # alignment of the inserted span: consecutive frames appended after
    # the untouched prefix
    for dur in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(dur)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + dur)

    # shift the untouched suffix by the span-length difference
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    # appending after the original sentence
    if span_to_repl[0] >= len(mfa_start):
        wav_left_idx = len(wav_org)
        wav_right_idx = wav_left_idx
    # replacing inside the original sentence
    else:
        wav_left_idx = int(np.floor(mfa_start[span_to_repl[0]] * n_shift))
        wav_right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * n_shift))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * n_shift)), ), dtype=wav_org.dtype)
    # original audio with the edited part replaced by silence; silence
    # length is decided by the fs2 duration_predictor
    new_wav = np.concatenate(
        [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]])

    # debug dump to verify the span is properly masked
    sf.write(str("new_wav.wav"), new_wav, samplerate=fs)

    # 4. get old and new mel span to be mask
    old_span_bdy = get_span_bdy(
        mfa_start=mfa_start, mfa_end=mfa_end, span_to_repl=span_to_repl)

    new_span_bdy = get_span_bdy(
        mfa_start=new_mfa_start, mfa_end=new_mfa_end, span_to_repl=span_to_add)

    # old_span_bdy / new_span_bdy are frame-level boundaries
    outs = {}
    outs['new_wav'] = new_wav
    outs['new_phns'] = new_phns
    outs['new_mfa_start'] = new_mfa_start
    outs['new_mfa_end'] = new_mfa_end
    outs['old_span_bdy'] = old_span_bdy
    outs['new_span_bdy'] = new_span_bdy
    return outs
def prep_feats(wav_path: str,
               old_str: str='',
               new_str: str='',
               source_lang: str='en',
               target_lang: str='en',
               duration_adjust: bool=True,
               fs: int=24000,
               n_shift: int=300):
    '''
    Turn the edited utterance into one collated ERNIE-SAT style batch.

    Relies on module-level globals set in ``__main__``: ``mel_extractor``,
    ``erniesat_stat``, ``collate_fn`` and (via ``_p2id``) ``vocab_phones``.

    Returns:
        Tuple: (collated batch,
                masked mel boundary of original wav (List[int]),
                masked mel boundary of new wav (List[int]))
    '''
    outs = prep_feats_with_dur(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang,
        duration_adjust=duration_adjust,
        fs=fs,
        n_shift=n_shift)

    wav_name = os.path.basename(wav_path)
    utt_id = wav_name.split('.')[0]

    wav = outs['new_wav']
    phns = outs['new_phns']
    mfa_start = outs['new_mfa_start']
    mfa_end = outs['new_mfa_end']
    old_span_bdy = outs['old_span_bdy']
    new_span_bdy = outs['new_span_bdy']
    span_bdy = np.array(new_span_bdy)

    text = _p2id(phns)
    mel = mel_extractor.get_log_mel_fbank(wav)
    erniesat_mean, erniesat_std = np.load(erniesat_stat)
    normed_mel = norm(mel, erniesat_mean, erniesat_std)
    tmp_name = get_tmp_name(text=old_str)
    tmpbase = Path('./tmp_dir/' + tmp_name)
    tmpbase.mkdir(parents=True, exist_ok=True)
    print("tmp_name in synthesize_e2e:", tmp_name)

    mel_path = tmpbase / 'mel.npy'
    print("mel_path:", mel_path)
    # BUGFIX: save the normalized mel computed above; the original saved
    # an undefined name ``logmel``, which raised a NameError at runtime.
    np.save(mel_path, normed_mel)
    durations = [e - s for e, s in zip(mfa_end, mfa_start)]

    datum = {
        "utt_id": utt_id,
        "spk_id": 0,
        "text": text,
        "text_lengths": len(text),
        # number of mel frames; was hard-coded to 115 in the original,
        # which only matched one specific utterance
        "speech_lengths": len(normed_mel),
        "durations": durations,
        "speech": mel_path,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "span_bdy": span_bdy
    }

    batch = collate_fn([datum])
    print("batch:", batch)

    return batch, old_span_bdy, new_span_bdy
def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      old_str: str='',
                      new_str: str='',
                      source_lang: str='en',
                      target_lang: str='en',
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      fs: int=24000,
                      n_shift: int=300,
                      token_list: List[str]=[]):
    '''
    Run ERNIE-SAT mlm inference on the edited utterance.

    Returns:
        Tuple: (original wav (np.ndarray),
                concatenated output features of the model,
                masked mel boundary of original wav,
                masked mel boundary of new wav,
                sample rate, hop size)
    '''
    # NOTE(review): ``token_list`` is kept for interface compatibility but
    # is unused; ``prep_feats`` accepts no such parameter, so forwarding it
    # (as the original code did) raised a TypeError.
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        fs=fs,
        n_shift=n_shift)

    # NOTE(review): ``prep_feats`` already returns the output of
    # ``collate_fn``; collating again and indexing ``[1]`` looks
    # suspicious — confirm the expected structure of ``batch`` against
    # ``build_erniesat_collate_fn`` before relying on this.
    feats = collate_fn(batch)[1]

    # the inference entry point does not consume text_masked_pos
    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the per-span outputs into one feature tensor
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    # BUGFIX: the hop size is ``n_shift``; the original returned an
    # undefined name ``hop_length``, which raised a NameError.
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, n_shift
if __name__ == '__main__':
    # Ad-hoc driver: edits one hard-coded VCTK utterance by appending a
    # synthesized sentence, then builds a collated batch via prep_feats.
    fs = 24000
    n_shift = 300
    wav_path = "exp/p243_313.wav"
    old_str = "For that reason cover should not be given."
    # for edit
    # new_str = "for that reason cover is impossible to be given."
    # for synthesize
    append_str = "do you love me i love you so much"
    new_str = old_str + append_str

    '''
    outs = prep_feats_with_dur(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        fs=fs,
        n_shift=n_shift)

    new_wav = outs['new_wav']
    new_phns = outs['new_phns']
    new_mfa_start = outs['new_mfa_start']
    new_mfa_end = outs['new_mfa_end']
    old_span_bdy = outs['old_span_bdy']
    new_span_bdy = outs['new_span_bdy']

    print("---------------------------------")

    print("new_wav:", new_wav)
    print("new_phns:", new_phns)
    print("new_mfa_start:", new_mfa_start)
    print("new_mfa_end:", new_mfa_end)
    print("old_span_bdy:", old_span_bdy)
    print("new_span_bdy:", new_span_bdy)
    print("---------------------------------")
    '''

    # NOTE(review): absolute developer-machine paths below; parameterize
    # via argparse before shipping.
    erniesat_config = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/local/default.yaml"

    with open(erniesat_config) as f:
        erniesat_config = CfgNode(yaml.safe_load(f))

    erniesat_stat = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/train/speech_stats.npy"

    # Extractor
    mel_extractor = LogMelFBank(
        sr=erniesat_config.fs,
        n_fft=erniesat_config.n_fft,
        hop_length=erniesat_config.n_shift,
        win_length=erniesat_config.win_length,
        window=erniesat_config.window,
        n_mels=erniesat_config.n_mels,
        fmin=erniesat_config.fmin,
        fmax=erniesat_config.fmax)

    collate_fn = build_erniesat_collate_fn(
        mlm_prob=erniesat_config.mlm_prob,
        mean_phn_span=erniesat_config.mean_phn_span,
        seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
        text_masking=False)

    phones_dict = '/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt'
    # phoneme -> integer id, consumed by _p2id
    vocab_phones = {}

    with open(phones_dict, 'rt') as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    # renamed loop variable: the original used `id`, shadowing the builtin
    for phn, phn_idx in phn_id:
        vocab_phones[phn] = int(phn_idx)

    prep_feats(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        fs=fs,
        n_shift=n_shift)