update config, test=doc

pull/2117/head
TianYuan 2 years ago
parent 97965f4c37
commit 9d4161ce5f

@ -79,13 +79,13 @@ grad_clip: 1.0
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 600
num_snapshots: 5
max_epoch: 1500
num_snapshots: 50
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086
seed: 0
token_list:
- <blank>

@ -79,13 +79,13 @@ grad_clip: 1.0
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 300
num_snapshots: 5
max_epoch: 700
num_snapshots: 50
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086
seed: 0
token_list:
- <blank>

@ -79,8 +79,8 @@ grad_clip: 1.0
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 600
num_snapshots: 5
max_epoch: 1500
num_snapshots: 50
###########################################################
# OTHER SETTING #

@ -0,0 +1,386 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from pathlib import Path
import librosa
import numpy as np
import pypinyin
from praatio import textgrid
from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
from paddlespeech.t2s.exps.ernie_sat.utils import get_dict
# Pronunciation dictionaries selected by language ('en' / 'zh') in this module.
# All paths below are relative to the current working directory.
DICT_EN = 'tools/aligner/cmudict-0.7b'
DICT_ZH = 'tools/aligner/simple.lexicon'
# Acoustic model archives passed to `mfa_align`, again one per language.
MODEL_DIR_EN = 'tools/aligner/vctk_model.zip'
MODEL_DIR_ZH = 'tools/aligner/aishell3_model.zip'
# Directory of the Montreal Forced Aligner executables. It is prepended to
# PATH at import time so `mfa_align` can be launched by its bare name.
MFA_PATH = 'tools/montreal-forced-aligner/bin'
os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH']
def _get_max_idx(dic):
    """Return the largest integer prefix among keys shaped like '<idx>_<word>'."""
    indices = [int(name.split('_')[0]) for name in dic]
    indices.sort()
    return indices[-1]
def _readtg(tg_path: str, lang: str='en', fs: int=24000, n_shift: int=300):
    """Read an aligner TextGrid into phone durations and a word -> phones map.

    Args:
        tg_path (str): path of the TextGrid file; it must contain a 'words'
            tier and a 'phones' tier.
        lang (str): 'en' or 'zh'; selects the pronunciation dictionary.
        fs (int): sample rate used to convert interval end times to frames.
        n_shift (int): hop length used for the same conversion.

    Returns:
        str: "phone duration phone duration ..." (durations in frames).
        Dict[str, List[str]]: key - idx_word (idx_sp for a pause),
            value - phones of that entry.

    Raises:
        ValueError: if lang is neither 'en' nor 'zh'.
        KeyError: if an aligned word is missing from the dictionary.
    """
    # Fail early with a clear message instead of an UnboundLocalError later.
    if lang == 'en':
        dict_path = DICT_EN
    elif lang == 'zh':
        dict_path = DICT_ZH
    else:
        raise ValueError(
            "unsupported lang '{}', expected 'en' or 'zh'".format(lang))

    tg = textgrid.openTextgrid(tg_path, includeEmptyIntervals=True)
    # empty labels on the 'words' tier are skipped
    words = [
        interval.label for interval in tg.tierDict['words'].entryList
        if interval.label
    ]
    phones = []
    ends = []
    for interval in tg.tierDict['phones'].entryList:
        phones.append(interval.label)
        ends.append(interval.end)
    frame_pos = librosa.time_to_frames(ends, sr=fs, hop_length=n_shift)
    durations = np.diff(frame_pos, prepend=0)
    assert len(durations) == len(phones)
    # merge a trailing '' into the 'sp' right before it
    if phones[-1] == '' and len(phones) > 1 and phones[-2] == 'sp':
        phones = phones[:-1]
        durations[-2] += durations[-1]
        durations = durations[:-1]
    # replace '' and 'sil' with 'sp'
    phones = ['sp' if (phn == '' or phn == 'sil') else phn for phn in phones]

    word2phns_dict = get_dict(dict_path)
    phn2word_dict = []
    for word in words:
        if lang == 'en':
            word = word.upper()
        phn2word_dict.append([word2phns_dict[word].split(), word])

    # Walk the phone sequence: every 'sp' becomes its own entry, every other
    # position consumes the dictionary phones of the next aligned word.
    non_sp_idx = 0
    word_idx = 0
    i = 0
    word2phns = {}
    while i < len(phones):
        if phones[i] == 'sp':
            word2phns[str(word_idx) + '_sp'] = ['sp']
            i += 1
        else:
            phns, word = phn2word_dict[non_sp_idx]
            word2phns[str(word_idx) + '_' + word] = phns
            non_sp_idx += 1
            i += len(phns)
        word_idx += 1
    sum_phn = sum(len(word2phns[k]) for k in word2phns)
    assert sum_phn == len(phones)

    results = ' '.join(p + ' ' + str(d) for p, d in zip(phones, durations))
    return results.strip(), word2phns
def alignment(wav_path: str,
              text: str,
              fs: int=24000,
              lang='en',
              n_shift: int=300):
    """Force-align `text` to the audio at `wav_path` by running `mfa_align`.

    The wav and a transcript file are copied into './tmp_dir/<tmp_name>',
    `mfa_align` is run on that directory, and the resulting TextGrid is parsed.

    Args:
        wav_path (str): path of the audio file.
        text (str): transcript of the audio.
        fs (int): sample rate used to convert times to frames.
        lang (str): 'en' or 'zh'; selects dictionary and acoustic model.
        n_shift (int): hop length used to convert times to frames.

    Returns:
        List[str]: aligned phones.
        List[int]: duration of every phone, in frames.
        Dict[str, List[str]]: key - idx_word, value - phones.

    Raises:
        ValueError: if lang is neither 'en' nor 'zh'.
    """
    # Validate first: previously an unsupported lang only printed a message
    # and then crashed with a NameError when the command was built.
    if lang == 'en':
        DICT = DICT_EN
        MODEL_DIR = MODEL_DIR_EN
    elif lang == 'zh':
        DICT = DICT_ZH
        MODEL_DIR = MODEL_DIR_ZH
    else:
        raise ValueError(
            "please input right lang!! got '{}', expected 'en' or 'zh'".format(
                lang))

    wav_name = os.path.basename(wav_path)
    utt = wav_name.split('.')[0]
    # prepare data for MFA
    tmp_name = get_tmp_name(text=text)
    tmpbase = Path('./tmp_dir/' + tmp_name)
    tmpbase.mkdir(parents=True, exist_ok=True)
    print("tmp_name in alignment:", tmp_name)
    shutil.copyfile(wav_path, tmpbase / wav_name)
    txt_path = tmpbase / (utt + '.txt')
    with open(txt_path, 'w') as wf:
        wf.write(text + '\n')

    # MFA: corpus dir and output dir are both tmpbase
    CMD = 'mfa_align' + ' ' + str(
        tmpbase) + ' ' + DICT + ' ' + MODEL_DIR + ' ' + str(tmpbase)
    os.system(CMD)

    tg_path = str(tmpbase) + '/' + tmp_name + '/' + utt + '.TextGrid'
    # forward fs / n_shift so non-default values are honoured
    phn_dur, word2phns = _readtg(tg_path, lang=lang, fs=fs, n_shift=n_shift)
    phn_dur = phn_dur.split()
    phns = phn_dur[::2]
    durs = [int(d) for d in phn_dur[1::2]]
    assert len(phns) == len(durs)
    return phns, durs, word2phns
def words2phns(text: str, lang='en'):
    '''
    Convert text to phones by dictionary lookup (no audio involved).

    Args:
        text (str):
            input text.
            eg: for that reason cover is impossible to be given.
        lang (str):
            'en' or 'zh'
    Returns:
        List[str]: phones of input text.
            eg:
            ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
            'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
            'G', 'IH1', 'V', 'AH0', 'N']
        Dict(str, List[str]): key - idx_word
                        value - phones
            eg:
            {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'],
            '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],'3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'],
            '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
            '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
        A word missing from the dictionary is mapped to ['spn'].
    Raises:
        ValueError: if lang is neither 'en' nor 'zh'.
    '''
    if lang == 'en':
        dictfile = DICT_EN
    elif lang == 'zh':
        dictfile = DICT_ZH
    else:
        raise ValueError(
            "please input right lang!! got '{}', expected 'en' or 'zh'".format(
                lang))

    text = text.strip()
    # '---' is listed before '--' so a triple dash does not leave a stray '-'.
    # The last eight entries are the full-width (CJK) punctuation marks; an
    # empty string here would make str.replace split every character.
    for pun in [
            ',', '.', ':', ';', '!', '?', '"', '(', ')', '---', '--', u',',
            u'。', u':', u';', u'!', u'?', u'(', u')'
    ]:
        text = text.replace(pun, ' ')

    words = []
    for wrd in text.split():
        # startswith / endswith are safe even when the token becomes empty
        if wrd.endswith('-'):
            wrd = wrd[:-1]
        if wrd.startswith("'"):
            wrd = wrd[1:]
        if wrd:
            words.append(wrd)

    word2phns_dict = get_dict(dictfile)
    phns = []
    wrd2phns = {}
    for index, wrd in enumerate(words):
        if lang == 'en':
            wrd = wrd.upper()
        if wrd in word2phns_dict:
            wrd_phns = word2phns_dict[wrd].split()
        else:
            # one 'spn' phone, not the three characters 's', 'p', 'n'
            wrd_phns = ['spn']
        wrd2phns[str(index) + '_' + wrd] = wrd_phns
        phns.extend(wrd_phns)
    return phns, wrd2phns
def get_phns_spans(wav_path: str,
                   old_str: str='',
                   new_str: str='',
                   source_lang: str='en',
                   target_lang: str='en',
                   fs: int=24000,
                   n_shift: int=300):
    """Locate the phone span that differs between `old_str` and `new_str`.

    `old_str` is force-aligned against the audio, `new_str` is converted to
    phones by dictionary lookup, and the two word -> phones maps are compared
    from the left and (for edits) from the right.

    Args:
        wav_path (str): audio whose transcript is `old_str`.
        old_str (str): original transcript.
        new_str (str): edited transcript; when it starts with `old_str` the
            call is treated as an append.
        source_lang (str): language of `old_str`.
        target_lang (str): language of the new text.
        fs (int): passed to `alignment`.
        n_shift (int): passed to `alignment`.

    Returns:
        dict with the keys 'mfa_start', 'mfa_end' (cumulative phone
        boundaries built from the aligned durations), 'old_phns',
        'new_phns', 'span_to_repl' (span in old_phns) and 'span_to_add'
        (span in new_phns).
    """
    is_append = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []
    # source
    lang = source_lang
    phn, dur, w2p = alignment(
        wav_path=wav_path, text=old_str, lang=lang, fs=fs, n_shift=n_shift)
    # cumulative sum with a leading 0: start of phone i is [i], end is [i + 1]
    new_d_cumsum = np.pad(np.array(dur).cumsum(0), (1, 0), 'constant').tolist()
    mfa_start = new_d_cumsum[:-1]
    mfa_end = new_d_cumsum[1:]
    old_phns = phn
    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
    else:
        cross_lingual_clone = False
    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]
        if target_lang == 'zh':
            phns_origin, origin_w2p = words2phns(str_origin, lang='en')
            phns_append, append_w2p_tmp = words2phns(str_append, lang='zh')
        elif target_lang == 'en':
            # original sentence
            phns_origin, origin_w2p = words2phns(str_origin, lang='zh')
            # sentence to clone
            phns_append, append_w2p_tmp = words2phns(str_append, lang='en')
        else:
            # always fails here; acts as the error for unsupported languages
            assert target_lang == 'zh' or target_lang == 'en', \
                'cloning is not support for this language, please check it.'
        new_phns = phns_origin + phns_append
        # shift the indices of the appended words past the original words
        append_w2p = {}
        length = len(origin_w2p)
        for key, value in append_w2p_tmp.items():
            idx, wrd = key.split('_')
            append_w2p[str(int(idx) + length) + '_' + wrd] = value
        new_w2p = origin_w2p.copy()
        new_w2p.update(append_w2p)
    else:
        if source_lang == target_lang:
            new_phns, new_w2p = words2phns(new_str, lang=source_lang)
        else:
            # always fails here; acts as the error for mismatched languages
            assert source_lang == target_lang, \
                'source language is not same with target language...'
    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
    new_phns_left = []
    sp_count = 0
    # find the left different index
    # the word -> phones map from alignment may contain 'sp' entries while the
    # one computed directly from text does not
    for key in w2p.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            # drop the 'sp' entries seen so far to get the index in new_w2p
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_w2p:
                # index into the phone sequence of new_str
                left_idx += len(new_w2p[idx + '_' + wrd])
                # phone sequence of old_str
                new_phns_left.extend(w2p[key])
            else:
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
                break
    # reverse w2p and new_w2p
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    w2p_max_idx = _get_max_idx(w2p)
    new_w2p_max_idx = _get_max_idx(new_w2p)
    new_phns_mid = []
    if is_append:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    # speech edit
    else:
        for key in list(w2p.keys())[::-1]:
            idx, wrd = key.split('_')
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                # same word counted from the end of new_w2p
                idx = str(new_w2p_max_idx - (w2p_max_idx - int(idx) - sp_count))
                if idx + '_' + wrd in new_w2p:
                    right_idx -= len(new_w2p[idx + '_' + wrd])
                    new_phns_right = w2p[key] + new_phns_right
                else:
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    new_phns_mid = new_phns[left_idx:right_idx]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        # nothing to insert: widen both spans by one phone on
                        # each side, clamped to the sequence bounds
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
                    break
    new_phns = new_phns_left + new_phns_mid + new_phns_right
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add: [17, 30] "is impossible to"
    '''
    outs = {}
    outs['mfa_start'] = mfa_start
    outs['mfa_end'] = mfa_end
    outs['old_phns'] = old_phns
    outs['new_phns'] = new_phns
    outs['span_to_repl'] = span_to_repl
    outs['span_to_add'] = span_to_add
    return outs
# Manual demo: runs the aligner and the span search on sample wav files that
# are expected under exp/ relative to the working directory.
if __name__ == '__main__':
    text = "For that reason cover should not be given."
    phn, dur, word2phns = alignment("exp/p243_313.wav", text, lang='en')
    print(phn, dur)
    print(word2phns)
    print("---------------------------------")
    # our Chinese frontend could be used here to get the pinyin sequence
    text_zh = "卡尔普陪外孙玩滑梯。"
    text_zh = pypinyin.lazy_pinyin(
        text_zh,
        neutral_tone_with_five=True,
        style=pypinyin.Style.TONE3,
        tone_sandhi=True)
    # the aligner receives the space-joined pinyin syllables as transcript
    text_zh = " ".join(text_zh)
    phn, dur, word2phns = alignment("exp/000001.wav", text_zh, lang='zh')
    print(phn, dur)
    print(word2phns)
    print("---------------------------------")
    # dictionary-only conversion, no audio involved
    phns, wrd2phns = words2phns(text, lang='en')
    print("phns:", phns)
    print("wrd2phns:", wrd2phns)
    print("---------------------------------")
    phns, wrd2phns = words2phns(text_zh, lang='zh')
    print("phns:", phns)
    print("wrd2phns:", wrd2phns)
    print("---------------------------------")
    # speech-edit case: "should not" -> "is impossible to"
    outs = get_phns_spans(
        wav_path="exp/p243_313.wav",
        old_str="For that reason cover should not be given.",
        new_str="for that reason cover is impossible to be given.")
    mfa_start = outs["mfa_start"]
    mfa_end = outs["mfa_end"]
    old_phns = outs["old_phns"]
    new_phns = outs["new_phns"]
    span_to_repl = outs["span_to_repl"]
    span_to_add = outs["span_to_add"]
    print("mfa_start:", mfa_start)
    print("mfa_end:", mfa_end)
    print("old_phns:", old_phns)
    print("new_phns:", new_phns)
    print("span_to_repl:", span_to_repl)
    print("span_to_add:", span_to_add)
    print("---------------------------------")

@ -0,0 +1,346 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import List

import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import nn
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs
from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor
from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy
from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import norm
def _p2id(phonemes: List[str]) -> np.ndarray:
    """Map phones to integer ids using the module-level `vocab_phones` table.

    The stray `self` parameter was removed: this is a plain module function
    and it is called with the phone list only.

    Args:
        phonemes (List[str]): phone symbols.

    Returns:
        np.ndarray: int64 ids, one per phone.

    Raises:
        KeyError: if a phone is unknown and 'sp' is not in `vocab_phones`.
    """
    # replace unk phone with sp
    known = [phn if phn in vocab_phones else "sp" for phn in phonemes]
    return np.array([vocab_phones[phn] for phn in known], np.int64)
def prep_feats_with_dur(wav_path: str,
                        old_str: str='',
                        new_str: str='',
                        source_lang: str='en',
                        target_lang: str='en',
                        duration_adjust: bool=True,
                        fs: int=24000,
                        n_shift: int=300):
    '''
    Build the waveform and alignment for the edited sentence.

    The phone span to edit is located with get_phns_spans, its new duration
    is predicted with eval_durs, and the matching part of the original wav
    is replaced by zeros of that length.

    Side effect: the new waveform is also written to "new_wav.wav" in the
    current working directory.

    Returns:
        dict with the keys
        'new_wav' (np.ndarray): new wav, the part to be edited in the
            original wav is replaced with 0
        'new_phns' (List[str]): new phones
        'new_mfa_start' (List): phone start boundaries of the new wav
        'new_mfa_end' (List): phone end boundaries of the new wav
        'old_span_bdy' (List[int]): masked mel boundary of original wav
        'new_span_bdy' (List[int]): masked mel boundary of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)
    phns_spans_outs = get_phns_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang,
        fs=fs,
        n_shift=n_shift)
    mfa_start = phns_spans_outs["mfa_start"]
    mfa_end = phns_spans_outs["mfa_end"]
    old_phns = phns_spans_outs["old_phns"]
    new_phns = phns_spans_outs["new_phns"]
    span_to_repl = phns_spans_outs["span_to_repl"]
    span_to_add = phns_spans_outs["span_to_add"]
    # Chinese phones are not necessarily all in the fastspeech2 dictionary;
    # unknown ones are replaced with sp
    # NOTE(review): the guard tests target_lang but the call uses source_lang
    # -- confirm which one is intended.
    if target_lang in {'en', 'zh'}:
        old_durs = eval_durs(old_phns, target_lang=source_lang)
    else:
        assert target_lang in {'en', 'zh'}, \
            "calculate duration_predict is not support for this language..."
    # durations measured by the aligner
    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
    if duration_adjust:
        d_factor = get_dur_adj_factor(
            orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
        # NOTE(review): 1.25 is an unexplained extra scaling -- confirm origin
        d_factor = d_factor * 1.25
    else:
        d_factor = 1
    if target_lang in {'en', 'zh'}:
        new_durs = eval_durs(new_phns, target_lang=target_lang)
    else:
        assert target_lang == "zh" or target_lang == "en", \
            "calculate duration_predict is not support for this language..."
    # durations must be integers
    new_durs_adjusted = [int(np.ceil(d_factor * i)) for i in new_durs]
    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    # amount by which everything after the edited span is shifted
    dur_offset = new_span_dur_sum - old_span_dur_sum
    # phones before the span keep their original boundaries
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
    # phones inside the span are laid out back to back with predicted durations
    for dur in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(dur)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + dur)
    # phones after the span are shifted by dur_offset
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]
    # 3. get new wav
    # append after the original sentence
    if span_to_repl[0] >= len(mfa_start):
        wav_left_idx = len(wav_org)
        wav_right_idx = wav_left_idx
    # replace inside the original sentence
    else:
        # boundaries times n_shift give sample indices into the wav
        wav_left_idx = int(np.floor(mfa_start[span_to_repl[0]] * n_shift))
        wav_right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * n_shift))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * n_shift)), ), dtype=wav_org.dtype)
    # original audio with the part to edit replaced by blank audio; the length
    # of the blank audio is decided by the fs2 duration predictor
    new_wav = np.concatenate(
        [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]])
    # the audio is masked as expected
    sf.write(str("new_wav.wav"), new_wav, samplerate=fs)
    # 4. get old and new mel span to be mask
    old_span_bdy = get_span_bdy(
        mfa_start=mfa_start, mfa_end=mfa_end, span_to_repl=span_to_repl)
    new_span_bdy = get_span_bdy(
        mfa_start=new_mfa_start, mfa_end=new_mfa_end, span_to_repl=span_to_add)
    # old_span_bdy and new_span_bdy are frame-level ranges
    outs = {}
    outs['new_wav'] = new_wav
    outs['new_phns'] = new_phns
    outs['new_mfa_start'] = new_mfa_start
    outs['new_mfa_end'] = new_mfa_end
    outs['old_span_bdy'] = old_span_bdy
    outs['new_span_bdy'] = new_span_bdy
    return outs
def prep_feats(wav_path: str,
               old_str: str='',
               new_str: str='',
               source_lang: str='en',
               target_lang: str='en',
               duration_adjust: bool=True,
               fs: int=24000,
               n_shift: int=300):
    """Build a collated batch for the edited utterance.

    Relies on the module-level names `mel_extractor`, `erniesat_stat` and
    `collate_fn` (and `vocab_phones` through `_p2id`), which are created in
    the `__main__` section of this file.

    Side effect: the normalized mel is saved to
    './tmp_dir/<tmp_name>/mel.npy'.

    Returns:
        batch: output of `collate_fn` for the single utterance.
        old_span_bdy: masked mel boundary of the original wav.
        new_span_bdy: masked mel boundary of the new wav.
    """
    outs = prep_feats_with_dur(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang,
        duration_adjust=duration_adjust,
        fs=fs,
        n_shift=n_shift)

    wav_name = os.path.basename(wav_path)
    utt_id = wav_name.split('.')[0]

    wav = outs['new_wav']
    phns = outs['new_phns']
    mfa_start = outs['new_mfa_start']
    mfa_end = outs['new_mfa_end']
    old_span_bdy = outs['old_span_bdy']
    new_span_bdy = outs['new_span_bdy']
    span_bdy = np.array(new_span_bdy)

    text = _p2id(phns)
    mel = mel_extractor.get_log_mel_fbank(wav)
    erniesat_mean, erniesat_std = np.load(erniesat_stat)
    normed_mel = norm(mel, erniesat_mean, erniesat_std)

    tmp_name = get_tmp_name(text=old_str)
    tmpbase = Path('./tmp_dir/' + tmp_name)
    tmpbase.mkdir(parents=True, exist_ok=True)
    print("tmp_name in synthesize_e2e:", tmp_name)

    mel_path = tmpbase / 'mel.npy'
    print("mel_path:", mel_path)
    # save the normalized mel; the name `logmel` used before was never defined
    np.save(mel_path, normed_mel)

    durations = [e - s for e, s in zip(mfa_end, mfa_start)]

    datum = {
        "utt_id": utt_id,
        "spk_id": 0,
        "text": text,
        "text_lengths": len(text),
        # NOTE(review): hard-coded frame count, probably should be derived
        # from normed_mel -- confirm against what collate_fn expects.
        "speech_lengths": 115,
        "durations": durations,
        # NOTE(review): a path is passed here, not the array -- confirm.
        "speech": mel_path,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "span_bdy": span_bdy
    }

    batch = collate_fn([datum])
    print("batch:", batch)
    return batch, old_span_bdy, new_span_bdy
def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      old_str: str='',
                      new_str: str='',
                      source_lang: str='en',
                      target_lang: str='en',
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      fs: int=24000,
                      n_shift: int=300,
                      token_list: List[str]=None):
    """Run the masked model on the edited utterance.

    Args:
        mlm_model: model exposing `inference(...)`.
        collate_fn: callable applied to the prepared batch.
        wav_path (str): original audio.
        old_str (str): original transcript.
        new_str (str): edited transcript.
        source_lang (str): language of old_str.
        target_lang (str): language of the new text.
        use_teacher_forcing (bool): forwarded to `mlm_model.inference`.
        duration_adjust (bool): forwarded to `prep_feats`.
        fs (int): sample rate.
        n_shift (int): hop length.
        token_list: kept for backward compatibility; not used because
            `prep_feats` has no such parameter.

    Returns:
        tuple: (wav_org, output_feat, old_span_bdy, new_span_bdy, fs, n_shift)
    """
    # `token_list` is not forwarded: prep_feats does not accept it and the
    # call would fail with a TypeError.
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        fs=fs,
        n_shift=n_shift)

    # NOTE(review): prep_feats already returns the output of a collate call;
    # confirm that collating again here is intended.
    feats = collate_fn(batch)[1]

    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the audio features
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    # the hop length is `n_shift`; `hop_length` was never defined
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, n_shift
# Manual demo: builds the module-level objects that prep_feats / _p2id read
# (mel_extractor, erniesat_stat, collate_fn, vocab_phones) and runs
# prep_feats once. The config, stats and phone-map paths are absolute paths
# on the author's machine and must be adapted before running elsewhere.
if __name__ == '__main__':
    fs = 24000
    n_shift = 300
    wav_path = "exp/p243_313.wav"
    old_str = "For that reason cover should not be given."
    # for edit
    # new_str = "for that reason cover is impossible to be given."
    # for synthesize
    append_str = "do you love me i love you so much"
    new_str = old_str + append_str
    # the string below is disabled demo code for prep_feats_with_dur
    '''
    outs = prep_feats_with_dur(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        fs=fs,
        n_shift=n_shift)
    new_wav = outs['new_wav']
    new_phns = outs['new_phns']
    new_mfa_start = outs['new_mfa_start']
    new_mfa_end = outs['new_mfa_end']
    old_span_bdy = outs['old_span_bdy']
    new_span_bdy = outs['new_span_bdy']
    print("---------------------------------")
    print("new_wav:", new_wav)
    print("new_phns:", new_phns)
    print("new_mfa_start:", new_mfa_start)
    print("new_mfa_end:", new_mfa_end)
    print("old_span_bdy:", old_span_bdy)
    print("new_span_bdy:", new_span_bdy)
    print("---------------------------------")
    '''
    erniesat_config = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/local/default.yaml"
    with open(erniesat_config) as f:
        erniesat_config = CfgNode(yaml.safe_load(f))
    erniesat_stat = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/train/speech_stats.npy"
    # Extractor
    mel_extractor = LogMelFBank(
        sr=erniesat_config.fs,
        n_fft=erniesat_config.n_fft,
        hop_length=erniesat_config.n_shift,
        win_length=erniesat_config.win_length,
        window=erniesat_config.window,
        n_mels=erniesat_config.n_mels,
        fmin=erniesat_config.fmin,
        fmax=erniesat_config.fmax)
    collate_fn = build_erniesat_collate_fn(
        mlm_prob=erniesat_config.mlm_prob,
        mean_phn_span=erniesat_config.mean_phn_span,
        seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
        text_masking=False)
    phones_dict='/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt'
    # phone -> id table, one "<phone> <id>" pair per line of phones_dict
    vocab_phones = {}
    with open(phones_dict, 'rt') as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    for phn, id in phn_id:
        vocab_phones[phn] = int(id)
    prep_feats(wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        fs=fs,
        n_shift=n_shift)

@ -0,0 +1,216 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union
import os
import numpy as np
import paddle
import yaml
from yacs.config import CfgNode
import hashlib
from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
def _get_user():
    """Return the last '/'-separated component of the expanded home directory."""
    home_dir = os.path.expanduser('~')
    return home_dir.rsplit('/', 1)[-1]
def str2md5(string):
    """Return the hexadecimal MD5 digest of `string` encoded as UTF-8."""
    digest = hashlib.md5()
    digest.update(string.encode('utf8'))
    return digest.hexdigest()
def get_tmp_name(text: str):
    """Build a name of the form '<user>_<pid>_<md5 of text>'."""
    parts = [_get_user(), str(os.getpid()), str2md5(text)]
    return '_'.join(parts)
def get_dict(dictfile: str):
    """Load a pronunciation lexicon into a word -> phones mapping.

    Each non-blank line is "<word> <phone> <phone> ...". When a word occurs
    more than once, only its first pronunciation is kept.

    Args:
        dictfile (str): path of the lexicon file.

    Returns:
        Dict[str, str]: word -> phones joined by single spaces.
    """
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            fields = line.split()
            if not fields:
                # blank line: nothing to index (used to raise IndexError)
                continue
            word, phn_lst = fields[0], fields[1:]
            if word not in word2phns_dict:
                word2phns_dict[word] = ' '.join(phn_lst)
    return word2phns_dict
# Get the range of mel frames that has to be masked.
def get_span_bdy(mfa_start: List[float],
                 mfa_end: List[float],
                 span_to_repl: List[List[int]]):
    """Return [start, end] boundaries covering the phone span `span_to_repl`.

    When the span begins past the last aligned phone, both boundaries are the
    end of the last phone.
    """
    if span_to_repl[0] < len(mfa_start):
        return [mfa_start[span_to_repl[0]], mfa_end[span_to_repl[1] - 1]]
    tail = mfa_end[-1]
    return [tail, tail]
# Durations measured by mfa and durations predicted by the fs2
# duration_predictor may differ; this computes a scaling ratio between the
# predicted and the measured values.
def get_dur_adj_factor(orig_dur: List[int],
                       pred_dur: List[int],
                       phns: List[str]):
    """Return the trimmed mean of orig/pred ratios, or 1 with fewer than 5 ratios.

    Phones equal to 'sp' and phones with a predicted duration of 0 are
    ignored; the two smallest and two largest ratios are dropped before
    averaging.
    """
    ratios = np.array([
        measured / predicted
        for measured, predicted, phone in zip(orig_dur, pred_dur, phns)
        if not (predicted == 0 or phone == 'sp')
    ])
    ratios.sort()
    if len(ratios) < 5:
        return 1
    trim = 2
    return np.average(ratios[trim:-trim])
def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
    """Parse a two-column text file into a dict.

    The first whitespace-separated token of a line is the key and the rest of
    the line is the value; a line holding only a key maps to "".

    Examples:
        wav.scp:
            key1 /some/path/a.wav
            key2 /some/path/b.wav
        >>> read_2col_text('wav.scp')
        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}

    Raises:
        RuntimeError: if a key occurs more than once.
    """
    table = {}
    with Path(path).open("r", encoding="utf-8") as handle:
        linenum = 0
        for raw in handle:
            linenum += 1
            fields = raw.rstrip().split(maxsplit=1)
            if len(fields) == 1:
                fields.append("")
            key, value = fields
            if key in table:
                raise RuntimeError(f"{key} is duplicated ({path}:{linenum})")
            table[key] = value
    return table
def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
                           ) -> Dict[str, List[Union[float, int]]]:
    """Read a text file indicating sequences of number

    Every line is "<key> <numbers>", where the numbers are separated by a
    space ("text_*" loaders) or by a comma ("csv_*" loaders).

    Examples:
        key1 1 2 3
        key2 34 5 6
        >>> d = load_num_sequence_text('text', loader_type='text_int')
        >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))

    Raises:
        ValueError: if loader_type is unknown or a value cannot be parsed.
    """
    if loader_type == "text_int":
        delimiter = " "
        dtype = int
    elif loader_type == "text_float":
        delimiter = " "
        dtype = float
    elif loader_type == "csv_int":
        delimiter = ","
        dtype = int
    elif loader_type == "csv_float":
        delimiter = ","
        dtype = float
    else:
        raise ValueError(f"Not supported loader_type={loader_type}")

    # path looks like:
    #   utta 1,0
    #   uttb 3,4,5
    # -> return {'utta': [1, 0],
    #            'uttb': [3, 4, 5]}
    # read_2col_text is the reader defined in this module
    d = read_2col_text(path)

    # Using for-loop instead of dict-comprehension for debuggability
    retval = {}
    for k, v in d.items():
        try:
            retval[k] = [dtype(i) for i in v.split(delimiter)]
        except (TypeError, ValueError):
            # int() / float() raise ValueError on malformed text
            print(f'Error happened with path="{path}", id="{k}", value="{v}"')
            raise
    return retval
def is_chinese(ch):
    """Tell whether `ch` compares within the range U+4E00 .. U+9FFF."""
    return u'\u4e00' <= ch <= u'\u9fff'
def get_voc_out(mel):
    """Run the vocoder described by the command-line arguments on `mel`.

    NOTE(review): `parse_args` is not defined in the visible part of this
    module -- confirm where it is expected to come from.

    Returns:
        the vocoder output passed through np.squeeze.
    """
    # vocoder
    cli_args = parse_args()
    with open(cli_args.voc_config) as cfg_file:
        vocoder_cfg = CfgNode(yaml.safe_load(cfg_file))
    vocoder = get_voc_inference(
        voc=cli_args.voc,
        voc_config=vocoder_cfg,
        voc_ckpt=cli_args.voc_ckpt,
        voc_stat=cli_args.voc_stat)
    with paddle.no_grad():
        generated = vocoder(mel)
    return np.squeeze(generated)
def eval_durs(phns, target_lang: str='zh', fs: int=24000, n_shift: int=300):
    """Predict phone durations with a pretrained fastspeech2 acoustic model.

    The checkpoint is looked up under 'download/' relative to the working
    directory and is loaded on every call.

    Args:
        phns: phone symbols; phones missing from the model's phone table are
            replaced with 'sp'.
        target_lang (str): 'en' or 'zh'; selects the checkpoint.
        fs (int): currently unused.
        n_shift (int): currently unused.

    Returns:
        list: durations returned by the model's `inference`, as a Python list.

    Raises:
        ValueError: if target_lang is neither 'en' nor 'zh'.
    """
    if target_lang == 'en':
        am = "fastspeech2_ljspeech"
        am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
        am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
        am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
        phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
    elif target_lang == 'zh':
        am = "fastspeech2_csmsc"
        am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
        am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
        am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
        phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
    else:
        # used to surface as an UnboundLocalError on am_config
        raise ValueError(
            f"unsupported target_lang '{target_lang}', expected 'en' or 'zh'")

    # Init body.
    with open(am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))

    # `am` is rebound from the model name to the model object
    am_inference, am = get_am_inference(
        am=am,
        am_config=am_config,
        am_ckpt=am_ckpt,
        am_stat=am_stat,
        phones_dict=phones_dict,
        return_am=True)

    # phone -> id table, one "<phone> <id>" pair per line
    vocab_phones = {}
    with open(phones_dict, "r") as f:
        for line in f:
            phone, phone_id = line.strip().split()
            vocab_phones[phone] = int(phone_id)

    phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
    phone_ids = [vocab_phones[item] for item in phonemes]
    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
    _, d_outs, _, _ = am.inference(phone_ids)
    return d_outs.tolist()
Loading…
Cancel
Save