diff --git a/examples/aishell3/ernie_sat/conf/default.yaml b/examples/aishell3/ernie_sat/conf/default.yaml
index d8993e86..fdc767fb 100644
--- a/examples/aishell3/ernie_sat/conf/default.yaml
+++ b/examples/aishell3/ernie_sat/conf/default.yaml
@@ -79,13 +79,13 @@ grad_clip: 1.0
 ###########################################################
 #                    TRAINING SETTING                     #
 ###########################################################
-max_epoch: 600
-num_snapshots: 5
+max_epoch: 1500
+num_snapshots: 50
 
 ###########################################################
 #                      OTHER SETTING                      #
 ###########################################################
-seed: 10086
+seed: 0
 
 token_list:
 -
diff --git a/examples/aishell3_vctk/ernie_sat/conf/default.yaml b/examples/aishell3_vctk/ernie_sat/conf/default.yaml
index 745a5b84..abb69fcc 100644
--- a/examples/aishell3_vctk/ernie_sat/conf/default.yaml
+++ b/examples/aishell3_vctk/ernie_sat/conf/default.yaml
@@ -79,13 +79,13 @@ grad_clip: 1.0
 ###########################################################
 #                    TRAINING SETTING                     #
 ###########################################################
-max_epoch: 300
-num_snapshots: 5
+max_epoch: 700
+num_snapshots: 50
 
 ###########################################################
 #                      OTHER SETTING                      #
 ###########################################################
-seed: 10086
+seed: 0
 
 token_list:
 -
diff --git a/examples/vctk/ernie_sat/conf/default.yaml b/examples/vctk/ernie_sat/conf/default.yaml
index b61c8170..672f937e 100644
--- a/examples/vctk/ernie_sat/conf/default.yaml
+++ b/examples/vctk/ernie_sat/conf/default.yaml
@@ -79,8 +79,8 @@ grad_clip: 1.0
 ###########################################################
 #                    TRAINING SETTING                     #
 ###########################################################
-max_epoch: 600
-num_snapshots: 5
+max_epoch: 1500
+num_snapshots: 50
 
 ###########################################################
 #                      OTHER SETTING                      #
diff --git a/paddlespeech/t2s/exps/ernie_sat/align.py b/paddlespeech/t2s/exps/ernie_sat/align.py
new file mode 100755
index 00000000..529a8221
--- /dev/null
+++ b/paddlespeech/t2s/exps/ernie_sat/align.py
@@ -0,0 +1,386 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import shutil
+from pathlib import Path
+
+import librosa
+import numpy as np
+import pypinyin
+from praatio import textgrid
+
+from paddlespeech.t2s.exps.ernie_sat.utils import get_dict
+from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
+
+DICT_EN = 'tools/aligner/cmudict-0.7b'
+DICT_ZH = 'tools/aligner/simple.lexicon'
+MODEL_DIR_EN = 'tools/aligner/vctk_model.zip'
+MODEL_DIR_ZH = 'tools/aligner/aishell3_model.zip'
+MFA_PATH = 'tools/montreal-forced-aligner/bin'
+os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH']
+
+
+def _get_max_idx(dic):
+    return max(int(key.split('_')[0]) for key in dic)
+
+
+def _readtg(tg_path: str, lang: str='en', fs: int=24000, n_shift: int=300):
+    alignment = textgrid.openTextgrid(tg_path, includeEmptyIntervals=True)
+    phones = []
+    ends = []
+    words = []
+
+    for interval in alignment.tierDict['words'].entryList:
+        word = interval.label
+        if word:
+            words.append(word)
+    for interval in alignment.tierDict['phones'].entryList:
+        phone = interval.label
+        phones.append(phone)
+        ends.append(interval.end)
+    # convert phone end times (seconds) to frame-level durations
+    frame_pos = librosa.time_to_frames(ends, sr=fs, hop_length=n_shift)
+    durations = np.diff(frame_pos, prepend=0)
+    assert len(durations) == len(phones)
+    # merge '' and 'sp' at the end
+    if phones[-1] == '' and len(phones) > 1 and phones[-2] == 'sp':
+        phones = phones[:-1]
+        durations[-2] += durations[-1]
+        durations = durations[:-1]
+
+    # replace '' and 'sil' with 'sp'
+    phones = ['sp' if (phn == '' or phn == 'sil') else phn for phn in phones]
+
+    if lang == 'en':
+        DICT = DICT_EN
+    elif lang == 'zh':
+        DICT = DICT_ZH
+    else:
+        raise ValueError(f"lang must be 'en' or 'zh', but got {lang}")
+
+    word2phns_dict = get_dict(DICT)
+
+    phn2word_dict = []
+    for word in words:
+        if lang == 'en':
+            word = word.upper()
+        phn2word_dict.append([word2phns_dict[word].split(), word])
+
+    non_sp_idx = 0
+    word_idx = 0
+    i = 0
+    word2phns = {}
+    while i < len(phones):
+        phn = phones[i]
+        if phn == 'sp':
+            word2phns[str(word_idx) + '_sp'] = ['sp']
+            i += 1
+        else:
+            phns, word = phn2word_dict[non_sp_idx]
+            word2phns[str(word_idx) + '_' + word] = phns
+            non_sp_idx += 1
+            i += len(phns)
+        word_idx += 1
+    sum_phn = sum(len(word2phns[k]) for k in word2phns)
+    assert sum_phn == len(phones)
+
+    results = ''
+    for (p, d) in zip(phones, durations):
+        results += p + ' ' + str(d) + ' '
+    return results.strip(), word2phns
+
+
+def alignment(wav_path: str,
+              text: str,
+              fs: int=24000,
+              lang='en',
+              n_shift: int=300):
+    wav_name = os.path.basename(wav_path)
+    utt = wav_name.split('.')[0]
+    # prepare data for MFA
+    tmp_name = get_tmp_name(text=text)
+    tmpbase = './tmp_dir/' + tmp_name
+    tmpbase = Path(tmpbase)
+    tmpbase.mkdir(parents=True, exist_ok=True)
+    print("tmp_name in alignment:", tmp_name)
+
+    shutil.copyfile(wav_path, tmpbase / wav_name)
+    txt_name = utt + '.txt'
+    txt_path = tmpbase / txt_name
+    with open(txt_path, 'w') as wf:
+        wf.write(text + '\n')
+    # MFA
+    if lang == 'en':
+        DICT = DICT_EN
+        MODEL_DIR = MODEL_DIR_EN
+    elif lang == 'zh':
+        DICT = DICT_ZH
+        MODEL_DIR = MODEL_DIR_ZH
+    else:
+        raise ValueError(f"lang must be 'en' or 'zh', but got {lang}")
+
+    CMD = 'mfa_align' + ' ' + str(
+        tmpbase) + ' ' + DICT + ' ' + MODEL_DIR + ' ' + str(tmpbase)
+    os.system(CMD)
+    tg_path = str(tmpbase) + '/' + tmp_name + '/' + utt + '.TextGrid'
+    phn_dur, word2phns = _readtg(tg_path, lang=lang)
+    phn_dur = phn_dur.split()
+    phns = phn_dur[::2]
+    durs = phn_dur[1::2]
+    durs = [int(d) for d in durs]
+    assert len(phns) == len(durs)
+    return phns, durs, word2phns
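+
+
+# _readtg converts phone end times to frame indices and then to per-phone
+# frame durations. A minimal worked example of that conversion, assuming the
+# defaults fs=24000 and n_shift=300 (i.e. 80 frames per second):
+#     >>> import numpy as np
+#     >>> import librosa
+#     >>> ends = [0.25, 0.50, 1.00]  # phone end times in seconds
+#     >>> frame_pos = librosa.time_to_frames(ends, sr=24000, hop_length=300)
+#     >>> frame_pos.tolist()
+#     [20, 40, 80]
+#     >>> np.diff(frame_pos, prepend=0).tolist()  # per-phone frame durations
+#     [20, 20, 40]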
+
+
+def words2phns(text: str, lang='en'):
+    '''
+    Args:
+        text (str):
+            input text.
+            eg: for that reason cover is impossible to be given.
+        lang (str):
+            'en' or 'zh'
+    Returns:
+        List[str]: phones of input text.
+            eg:
+            ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
+             'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
+             'G', 'IH1', 'V', 'AH0', 'N']
+        Dict[str, List[str]]: key - idx_word
+                              value - phones
+            eg:
+            {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'],
+             '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'], '3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'],
+             '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
+             '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
+    '''
+    text = text.strip()
+    words = []
+    for pun in [
+            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
+            u'。', u':', u';', u'!', u'?', u'(', u')'
+    ]:
+        text = text.replace(pun, ' ')
+    for wrd in text.split():
+        if (wrd[-1] == '-'):
+            wrd = wrd[:-1]
+        if (wrd[0] == "'"):
+            wrd = wrd[1:]
+        if wrd:
+            words.append(wrd)
+    if lang == 'en':
+        dictfile = DICT_EN
+    elif lang == 'zh':
+        dictfile = DICT_ZH
+    else:
+        raise ValueError(f"lang must be 'en' or 'zh', but got {lang}")
+
+    word2phns_dict = get_dict(dictfile)
+    ds = word2phns_dict.keys()
+    phns = []
+    wrd2phns = {}
+    for index, wrd in enumerate(words):
+        if lang == 'en':
+            wrd = wrd.upper()
+        if (wrd not in ds):
+            # out-of-vocabulary word: map it to the unknown phone 'spn'
+            wrd2phns[str(index) + '_' + wrd] = ['spn']
+            phns.append('spn')
+        else:
+            wrd2phns[str(index) + '_' + wrd] = word2phns_dict[wrd].split()
+            phns.extend(word2phns_dict[wrd].split())
+    return phns, wrd2phns
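+
+
+# The '{idx}_{word}' key convention keeps duplicate words distinct and lets
+# get_phns_spans below merge two word2phns dicts by re-numbering the appended
+# half (a minimal sketch with made-up entries):
+#     >>> origin_w2p = {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T']}
+#     >>> append_w2p_tmp = {'0_ni3': ['n', 'i3']}
+#     >>> length = len(origin_w2p)
+#     >>> {str(int(k.split('_')[0]) + length) + '_' + k.split('_')[1]: v
+#     ...  for k, v in append_w2p_tmp.items()}
+#     {'2_ni3': ['n', 'i3']}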
+
+
+def get_phns_spans(wav_path: str,
+                   old_str: str='',
+                   new_str: str='',
+                   source_lang: str='en',
+                   target_lang: str='en',
+                   fs: int=24000,
+                   n_shift: int=300):
+    is_append = (old_str == new_str[:len(old_str)])
+    old_phns, mfa_start, mfa_end = [], [], []
+    # source
+    lang = source_lang
+    phn, dur, w2p = alignment(
+        wav_path=wav_path, text=old_str, lang=lang, fs=fs, n_shift=n_shift)
+
+    new_d_cumsum = np.pad(np.array(dur).cumsum(0), (1, 0), 'constant').tolist()
+    mfa_start = new_d_cumsum[:-1]
+    mfa_end = new_d_cumsum[1:]
+    old_phns = phn
+
+    # target
+    if is_append and (source_lang != target_lang):
+        cross_lingual_clone = True
+    else:
+        cross_lingual_clone = False
+
+    if cross_lingual_clone:
+        str_origin = new_str[:len(old_str)]
+        str_append = new_str[len(old_str):]
+
+        if target_lang == 'zh':
+            phns_origin, origin_w2p = words2phns(str_origin, lang='en')
+            phns_append, append_w2p_tmp = words2phns(str_append, lang='zh')
+        elif target_lang == 'en':
+            # original sentence
+            phns_origin, origin_w2p = words2phns(str_origin, lang='zh')
+            # cloned (appended) sentence
+            phns_append, append_w2p_tmp = words2phns(str_append, lang='en')
+        else:
+            assert target_lang == 'zh' or target_lang == 'en', \
+                'cloning is not supported for this language, please check it.'
+
+        new_phns = phns_origin + phns_append
+
+        append_w2p = {}
+        length = len(origin_w2p)
+        for key, value in append_w2p_tmp.items():
+            idx, wrd = key.split('_')
+            append_w2p[str(int(idx) + length) + '_' + wrd] = value
+        new_w2p = origin_w2p.copy()
+        new_w2p.update(append_w2p)
+
+    else:
+        if source_lang == target_lang:
+            new_phns, new_w2p = words2phns(new_str, lang=source_lang)
+        else:
+            assert source_lang == target_lang, \
+                'source language differs from target language...'
+
+    span_to_repl = [0, len(old_phns) - 1]
+    span_to_add = [0, len(new_phns) - 1]
+    left_idx = 0
+    new_phns_left = []
+    sp_count = 0
+    # find the left-most differing index
+    # the word2phns from alignment() may contain extra 'sp' entries that a
+    # plain words2phns() does not, so the indices have to be re-aligned
+    for key in w2p.keys():
+        idx, wrd = key.split('_')
+        if wrd == 'sp':
+            sp_count += 1
+            new_phns_left.append('sp')
+        else:
+            idx = str(int(idx) - sp_count)
+            if idx + '_' + wrd in new_w2p:
+                # index into the new_str phone sequence
+                left_idx += len(new_w2p[idx + '_' + wrd])
+                # old phone sequence
+                new_phns_left.extend(w2p[key])
+            else:
+                span_to_repl[0] = len(new_phns_left)
+                span_to_add[0] = len(new_phns_left)
+                break
+
+    # traverse w2p and new_w2p in reverse
+    right_idx = 0
+    new_phns_right = []
+    sp_count = 0
+    w2p_max_idx = _get_max_idx(w2p)
+    new_w2p_max_idx = _get_max_idx(new_w2p)
+    new_phns_mid = []
+    if is_append:
+        new_phns_right = []
+        new_phns_mid = new_phns[left_idx:]
+        span_to_repl[0] = len(new_phns_left)
+        span_to_add[0] = len(new_phns_left)
+        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
+        span_to_repl[1] = len(old_phns) - len(new_phns_right)
+    # speech edit
+    else:
+        for key in list(w2p.keys())[::-1]:
+            idx, wrd = key.split('_')
+            if wrd == 'sp':
+                sp_count += 1
+                new_phns_right = ['sp'] + new_phns_right
+            else:
+                idx = str(new_w2p_max_idx - (w2p_max_idx - int(idx) - sp_count))
+                if idx + '_' + wrd in new_w2p:
+                    right_idx -= len(new_w2p[idx + '_' + wrd])
+                    new_phns_right = w2p[key] + new_phns_right
+                else:
+                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
+                    new_phns_mid = new_phns[left_idx:right_idx]
+                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
+                    if len(new_phns_mid) == 0:
+                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
+                        span_to_add[0] = max(0, span_to_add[0] - 1)
+                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
+                        span_to_repl[1] = min(span_to_repl[1] + 1,
+                                              len(old_phns))
+                    break
+    new_phns = new_phns_left + new_phns_mid + new_phns_right
+    '''
+    For that reason cover should not be given.
+    For that reason cover is impossible to be given.
+    span_to_repl: [17, 23] "should not"
+    span_to_add: [17, 30] "is impossible to"
+    '''
+    outs = {}
+    outs['mfa_start'] = mfa_start
+    outs['mfa_end'] = mfa_end
+    outs['old_phns'] = old_phns
+    outs['new_phns'] = new_phns
+    outs['span_to_repl'] = span_to_repl
+    outs['span_to_add'] = span_to_add
+
+    return outs
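+
+
+# Conceptually, get_phns_spans finds the longest matching word prefix and
+# suffix of old_str/new_str and marks everything in between for replacement.
+# The same idea on plain word lists (a simplified sketch that ignores the
+# 'sp' bookkeeping above):
+#     >>> import os
+#     >>> old = ['for', 'that', 'reason', 'cover', 'should', 'not', 'be', 'given']
+#     >>> new = ['for', 'that', 'reason', 'cover', 'is', 'impossible', 'to', 'be', 'given']
+#     >>> left = len(os.path.commonprefix([old, new]))
+#     >>> right = len(os.path.commonprefix([old[::-1], new[::-1]]))
+#     >>> old[left:len(old) - right], new[left:len(new) - right]
+#     (['should', 'not'], ['is', 'impossible', 'to'])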
+
+
+if __name__ == '__main__':
+    text = "For that reason cover should not be given."
+    phn, dur, word2phns = alignment("exp/p243_313.wav", text, lang='en')
+    print(phn, dur)
+    print(word2phns)
+    print("---------------------------------")
+    # the Chinese text frontend can be used here to get the pinyin sequence
+    text_zh = "卡尔普陪外孙玩滑梯。"
+    text_zh = pypinyin.lazy_pinyin(
+        text_zh,
+        neutral_tone_with_five=True,
+        style=pypinyin.Style.TONE3,
+        tone_sandhi=True)
+    text_zh = " ".join(text_zh)
+    phn, dur, word2phns = alignment("exp/000001.wav", text_zh, lang='zh')
+    print(phn, dur)
+    print(word2phns)
+    print("---------------------------------")
+
+    phns, wrd2phns = words2phns(text, lang='en')
+    print("phns:", phns)
+    print("wrd2phns:", wrd2phns)
+    print("---------------------------------")
+
+    phns, wrd2phns = words2phns(text_zh, lang='zh')
+    print("phns:", phns)
+    print("wrd2phns:", wrd2phns)
+    print("---------------------------------")
+
+    outs = get_phns_spans(
+        wav_path="exp/p243_313.wav",
+        old_str="For that reason cover should not be given.",
+        new_str="for that reason cover is impossible to be given.")
+
+    mfa_start = outs["mfa_start"]
+    mfa_end = outs["mfa_end"]
+    old_phns = outs["old_phns"]
+    new_phns = outs["new_phns"]
+    span_to_repl = outs["span_to_repl"]
+    span_to_add = outs["span_to_add"]
+    print("mfa_start:", mfa_start)
+    print("mfa_end:", mfa_end)
+    print("old_phns:", old_phns)
+    print("new_phns:", new_phns)
+    print("span_to_repl:", span_to_repl)
+    print("span_to_add:", span_to_add)
+    print("---------------------------------")
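+
+    # Note: alignment() leaves its intermediate files under ./tmp_dir and
+    # never removes them; if disk usage matters they can be cleaned up
+    # afterwards (a sketch -- only safe when nothing else uses the directory):
+    #     >>> import shutil
+    #     >>> shutil.rmtree('./tmp_dir', ignore_errors=True)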
diff --git a/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py b/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py
new file mode 100644
index 00000000..95b07367
--- /dev/null
+++ b/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py
@@ -0,0 +1,346 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from pathlib import Path
+from typing import List
+
+import librosa
+import numpy as np
+import paddle
+import soundfile as sf
+import yaml
+from paddle import nn
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
+from paddlespeech.t2s.datasets.get_feats import LogMelFBank
+from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
+from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs
+from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor
+from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy
+from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
+from paddlespeech.t2s.exps.syn_utils import get_frontend
+from paddlespeech.t2s.exps.syn_utils import norm
+
+
+def _p2id(phonemes: List[str]) -> np.ndarray:
+    # replace unknown phones with 'sp'
+    phonemes = [phn if phn in vocab_phones else "sp" for phn in phonemes]
+    phone_ids = [vocab_phones[item] for item in phonemes]
+    return np.array(phone_ids, np.int64)
+
+
+def prep_feats_with_dur(wav_path: str,
+                        old_str: str='',
+                        new_str: str='',
+                        source_lang: str='en',
+                        target_lang: str='en',
+                        duration_adjust: bool=True,
+                        fs: int=24000,
+                        n_shift: int=300):
+    '''
+    Returns a dict with keys:
+        new_wav (np.ndarray): new wav, with the span to be edited in the
+            original wav replaced by zeros
+        new_phns (List[str]): new phones
+        new_mfa_start (List[int]): frame-level start of each new phone
+        new_mfa_end (List[int]): frame-level end of each new phone
+        old_span_bdy (List[int]): masked mel boundary of original wav
+        new_span_bdy (List[int]): masked mel boundary of new wav
+    '''
+    wav_org, _ = librosa.load(wav_path, sr=fs)
+    phns_spans_outs = get_phns_spans(
+        wav_path=wav_path,
+        old_str=old_str,
+        new_str=new_str,
+        source_lang=source_lang,
+        target_lang=target_lang,
+        fs=fs,
+        n_shift=n_shift)
+
+    mfa_start = phns_spans_outs["mfa_start"]
+    mfa_end = phns_spans_outs["mfa_end"]
+    old_phns = phns_spans_outs["old_phns"]
+    new_phns = phns_spans_outs["new_phns"]
+    span_to_repl = phns_spans_outs["span_to_repl"]
+    span_to_add = phns_spans_outs["span_to_add"]
+
+    # Chinese phones are not guaranteed to be in the fastspeech2 dictionary;
+    # missing ones are replaced by 'sp'
+    if target_lang in {'en', 'zh'}:
+        old_durs = eval_durs(old_phns, target_lang=source_lang)
+    else:
+        assert target_lang in {'en', 'zh'}, \
+            "duration prediction is not supported for this language..."
+
+    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
+
+    if duration_adjust:
+        d_factor = get_dur_adj_factor(
+            orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
+        d_factor = d_factor * 1.25
+    else:
+        d_factor = 1
+
+    if target_lang in {'en', 'zh'}:
+        new_durs = eval_durs(new_phns, target_lang=target_lang)
+    else:
+        assert target_lang in {'en', 'zh'}, \
+            "duration prediction is not supported for this language..."
+
+    # durations have to be integers (frame counts)
+    new_durs_adjusted = [int(np.ceil(d_factor * i)) for i in new_durs]
+
+    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
+    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
+    dur_offset = new_span_dur_sum - old_span_dur_sum
+    new_mfa_start = mfa_start[:span_to_repl[0]]
+    new_mfa_end = mfa_end[:span_to_repl[0]]
+
+    for dur in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
+        if len(new_mfa_end) == 0:
+            new_mfa_start.append(0)
+            new_mfa_end.append(dur)
+        else:
+            new_mfa_start.append(new_mfa_end[-1])
+            new_mfa_end.append(new_mfa_end[-1] + dur)
+
+    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
+    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]
+
+    # 3. get new wav
+    # append after the original utterance
+    if span_to_repl[0] >= len(mfa_start):
+        wav_left_idx = len(wav_org)
+        wav_right_idx = wav_left_idx
+    # replace inside the original utterance
+    else:
+        wav_left_idx = int(np.floor(mfa_start[span_to_repl[0]] * n_shift))
+        wav_right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * n_shift))
+    blank_wav = np.zeros(
+        (int(np.ceil(new_span_dur_sum * n_shift)), ), dtype=wav_org.dtype)
+    # in the original audio, the span to be edited is replaced with silence
+    # whose length is decided by the fs2 duration_predictor
+    new_wav = np.concatenate(
+        [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]])
+
+    # sanity check: the edited span should be blanked out in this file
+    sf.write(str("new_wav.wav"), new_wav, samplerate=fs)
+
+    # 4. get old and new mel spans to be masked
+    old_span_bdy = get_span_bdy(
+        mfa_start=mfa_start, mfa_end=mfa_end, span_to_repl=span_to_repl)
+
+    new_span_bdy = get_span_bdy(
+        mfa_start=new_mfa_start, mfa_end=new_mfa_end, span_to_repl=span_to_add)
+
+    # old_span_bdy and new_span_bdy are frame-level ranges
+    outs = {}
+    outs['new_wav'] = new_wav
+    outs['new_phns'] = new_phns
+    outs['new_mfa_start'] = new_mfa_start
+    outs['new_mfa_end'] = new_mfa_end
+    outs['old_span_bdy'] = old_span_bdy
+    outs['new_span_bdy'] = new_span_bdy
+    return outs
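+
+
+# Worked numbers for the duration adjustment above (illustrative values):
+# with d_factor = 1.2 and predicted durations [4, 7, 5] for the inserted span,
+#     >>> import numpy as np
+#     >>> [int(np.ceil(1.2 * d)) for d in [4, 7, 5]]
+#     [5, 9, 6]
+# so new_span_dur_sum = 20 frames; if the replaced span summed to 14 frames,
+# dur_offset = 6, every frame index after the span shifts right by 6, and the
+# blanked waveform region spans 20 * n_shift = 6000 samples at n_shift=300.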
+
+
+def prep_feats(wav_path: str,
+               old_str: str='',
+               new_str: str='',
+               source_lang: str='en',
+               target_lang: str='en',
+               duration_adjust: bool=True,
+               fs: int=24000,
+               n_shift: int=300):
+
+    outs = prep_feats_with_dur(
+        wav_path=wav_path,
+        old_str=old_str,
+        new_str=new_str,
+        source_lang=source_lang,
+        target_lang=target_lang,
+        duration_adjust=duration_adjust,
+        fs=fs,
+        n_shift=n_shift)
+
+    wav_name = os.path.basename(wav_path)
+    utt_id = wav_name.split('.')[0]
+
+    wav = outs['new_wav']
+    phns = outs['new_phns']
+    mfa_start = outs['new_mfa_start']
+    mfa_end = outs['new_mfa_end']
+    old_span_bdy = outs['old_span_bdy']
+    new_span_bdy = outs['new_span_bdy']
+    span_bdy = np.array(new_span_bdy)
+
+    text = _p2id(phns)
+    mel = mel_extractor.get_log_mel_fbank(wav)
+    erniesat_mean, erniesat_std = np.load(erniesat_stat)
+    normed_mel = norm(mel, erniesat_mean, erniesat_std)
+    tmp_name = get_tmp_name(text=old_str)
+    tmpbase = './tmp_dir/' + tmp_name
+    tmpbase = Path(tmpbase)
+    tmpbase.mkdir(parents=True, exist_ok=True)
+    print("tmp_name in synthesize_e2e:", tmp_name)
+
+    mel_path = tmpbase / 'mel.npy'
+    print("mel_path:", mel_path)
+    np.save(mel_path, normed_mel)
+    durations = [e - s for e, s in zip(mfa_end, mfa_start)]
+
+    datum = {
+        "utt_id": utt_id,
+        "spk_id": 0,
+        "text": text,
+        "text_lengths": len(text),
+        "speech_lengths": len(normed_mel),
+        "durations": durations,
+        "speech": mel_path,
+        "align_start": mfa_start,
+        "align_end": mfa_end,
+        "span_bdy": span_bdy
+    }
+
+    batch = collate_fn([datum])
+    print("batch:", batch)
+
+    return batch, old_span_bdy, new_span_bdy
+
+
+def decode_with_model(mlm_model: nn.Layer,
+                      collate_fn,
+                      wav_path: str,
+                      old_str: str='',
+                      new_str: str='',
+                      source_lang: str='en',
+                      target_lang: str='en',
+                      use_teacher_forcing: bool=False,
+                      duration_adjust: bool=True,
+                      fs: int=24000,
+                      n_shift: int=300,
+                      token_list: List[str]=[]):
+    batch, old_span_bdy, new_span_bdy = prep_feats(
+        source_lang=source_lang,
+        target_lang=target_lang,
+        wav_path=wav_path,
+        old_str=old_str,
+        new_str=new_str,
+        duration_adjust=duration_adjust,
+        fs=fs,
+        n_shift=n_shift)
+
+    # prep_feats already ran collate_fn, so the batch can be used directly
+    feats = batch
+
+    if 'text_masked_pos' in feats.keys():
+        feats.pop('text_masked_pos')
+
+    output = mlm_model.inference(
+        text=feats['text'],
+        speech=feats['speech'],
+        masked_pos=feats['masked_pos'],
+        speech_mask=feats['speech_mask'],
+        text_mask=feats['text_mask'],
+        speech_seg_pos=feats['speech_seg_pos'],
+        text_seg_pos=feats['text_seg_pos'],
+        span_bdy=new_span_bdy,
+        use_teacher_forcing=use_teacher_forcing)
+
+    # concatenate the output mel chunks
+    output_feat = paddle.concat(x=output, axis=0)
+    wav_org, _ = librosa.load(wav_path, sr=fs)
+    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, n_shift
+
+
+if __name__ == '__main__':
+    fs = 24000
+    n_shift = 300
+    wav_path = "exp/p243_313.wav"
+    old_str = "For that reason cover should not be given."
+    # for editing
+    # new_str = "for that reason cover is impossible to be given."
+    # for synthesis
+    append_str = "do you love me i love you so much"
+    new_str = old_str + ' ' + append_str
+    '''
+    outs = prep_feats_with_dur(
+        wav_path=wav_path,
+        old_str=old_str,
+        new_str=new_str,
+        fs=fs,
+        n_shift=n_shift)
+
+    new_wav = outs['new_wav']
+    new_phns = outs['new_phns']
+    new_mfa_start = outs['new_mfa_start']
+    new_mfa_end = outs['new_mfa_end']
+    old_span_bdy = outs['old_span_bdy']
+    new_span_bdy = outs['new_span_bdy']
+
+    print("---------------------------------")
+    print("new_wav:", new_wav)
+    print("new_phns:", new_phns)
+    print("new_mfa_start:", new_mfa_start)
+    print("new_mfa_end:", new_mfa_end)
+    print("old_span_bdy:", old_span_bdy)
+    print("new_span_bdy:", new_span_bdy)
+    print("---------------------------------")
+    '''
+
+    erniesat_config = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/local/default.yaml"
+
+    with open(erniesat_config) as f:
+        erniesat_config = CfgNode(yaml.safe_load(f))
+
+    erniesat_stat = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/train/speech_stats.npy"
+
+    # Extractor
+    mel_extractor = LogMelFBank(
+        sr=erniesat_config.fs,
+        n_fft=erniesat_config.n_fft,
+        hop_length=erniesat_config.n_shift,
+        win_length=erniesat_config.win_length,
+        window=erniesat_config.window,
+        n_mels=erniesat_config.n_mels,
+        fmin=erniesat_config.fmin,
+        fmax=erniesat_config.fmax)
+
+    collate_fn = build_erniesat_collate_fn(
+        mlm_prob=erniesat_config.mlm_prob,
+        mean_phn_span=erniesat_config.mean_phn_span,
+        seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
+        text_masking=False)
+
+    phones_dict = '/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt'
+    vocab_phones = {}
+
+    with open(phones_dict, 'rt') as f:
+        phn_id = [line.strip().split() for line in f.readlines()]
+    for phn, id_ in phn_id:
+        vocab_phones[phn] = int(id_)
+
+    prep_feats(
+        wav_path=wav_path,
+        old_str=old_str,
+        new_str=new_str,
+        fs=fs,
+        n_shift=n_shift)
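+
+    # __main__ above stops at feature preparation; a sketch of the remaining
+    # step, assuming an ERNIE-SAT model `mlm_model` restored from
+    # erniesat_config and a checkpoint (model loading is not shown here):
+    #     >>> wav_org, output_feat, old_span_bdy, new_span_bdy, fs, n_shift = \
+    #     ...     decode_with_model(mlm_model, collate_fn, wav_path,
+    #     ...                       old_str=old_str, new_str=new_str,
+    #     ...                       fs=fs, n_shift=n_shift)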
diff --git a/paddlespeech/t2s/exps/ernie_sat/utils.py b/paddlespeech/t2s/exps/ernie_sat/utils.py
new file mode 100644
index 00000000..9169efa3
--- /dev/null
+++ b/paddlespeech/t2s/exps/ernie_sat/utils.py
@@ -0,0 +1,216 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import hashlib
+import os
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Union
+
+import numpy as np
+import paddle
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+
+
+def _get_user():
+    return os.path.expanduser('~').split('/')[-1]
+
+
+def str2md5(string):
+    md5_val = hashlib.md5(string.encode('utf8')).hexdigest()
+    return md5_val
+
+
+def get_tmp_name(text: str):
+    return _get_user() + '_' + str(os.getpid()) + '_' + str2md5(text)
+
+
+def get_dict(dictfile: str):
+    word2phns_dict = {}
+    with open(dictfile, 'r') as fid:
+        for line in fid:
+            line_lst = line.split()
+            word, phn_lst = line_lst[0], line_lst[1:]
+            if word not in word2phns_dict:
+                word2phns_dict[word] = ' '.join(phn_lst)
+    return word2phns_dict
+
+
+# get the range of mel frames to be masked
+def get_span_bdy(mfa_start: List[float],
+                 mfa_end: List[float],
+                 span_to_repl: List[List[int]]):
+    if span_to_repl[0] >= len(mfa_start):
+        span_bdy = [mfa_end[-1], mfa_end[-1]]
+    else:
+        span_bdy = [mfa_start[span_to_repl[0]], mfa_end[span_to_repl[1] - 1]]
+    return span_bdy
+
+
+# durations from MFA and durations from the fs2 duration_predictor may differ,
+# so compute a scaling factor between the ground-truth and predicted values
+def get_dur_adj_factor(orig_dur: List[int],
+                       pred_dur: List[int],
+                       phns: List[str]):
+    factor_list = []
+    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
+        if pred == 0 or phn == 'sp':
+            continue
+        else:
+            factor_list.append(orig / pred)
+    factor_list = np.array(factor_list)
+    factor_list.sort()
+    if len(factor_list) < 5:
+        return 1
+    # drop the 2 smallest and 2 largest ratios before averaging
+    length = 2
+    avg = np.average(factor_list[length:-length])
+    return avg
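+
+
+# Worked example of the trimmed averaging above (illustrative values):
+#     >>> get_dur_adj_factor(
+#     ...     orig_dur=[4, 6, 8, 10, 5, 9],
+#     ...     pred_dur=[4, 5, 8, 8, 4, 6],
+#     ...     phns=['a', 'b', 'c', 'd', 'e', 'f'])
+# ratios: [1.0, 1.2, 1.0, 1.25, 1.25, 1.5] -> sorted, then trimmed to
+# [1.2, 1.25], so the returned factor is 1.225.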
+
+
+def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
+    """Read a 2-column text file as a dict.
+
+    Examples:
+        wav.scp:
+            key1 /some/path/a.wav
+            key2 /some/path/b.wav
+
+        >>> read_2col_text('wav.scp')
+        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
+
+    """
+    data = {}
+    with Path(path).open("r", encoding="utf-8") as f:
+        for linenum, line in enumerate(f, 1):
+            sps = line.rstrip().split(maxsplit=1)
+            if len(sps) == 1:
+                k, v = sps[0], ""
+            else:
+                k, v = sps
+            if k in data:
+                raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+            data[k] = v
+    return data
+
+
+def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
+                           ) -> Dict[str, List[Union[float, int]]]:
+    """Read a text file of keyed number sequences.
+
+    Examples:
+        text:
+            key1 1 2 3
+            key2 34 5 6
+
+        >>> d = load_num_sequence_text('text')
+        >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
+    """
+    if loader_type == "text_int":
+        delimiter = " "
+        dtype = int
+    elif loader_type == "text_float":
+        delimiter = " "
+        dtype = float
+    elif loader_type == "csv_int":
+        delimiter = ","
+        dtype = int
+    elif loader_type == "csv_float":
+        delimiter = ","
+        dtype = float
+    else:
+        raise ValueError(f"Not supported loader_type={loader_type}")
+
+    # path looks like:
+    #   utta 1,0
+    #   uttb 3,4,5
+    # -> return {'utta': np.ndarray([1, 0]),
+    #            'uttb': np.ndarray([3, 4, 5])}
+    d = read_2col_text(path)
+    # Using a for-loop instead of a dict comprehension for debuggability
+    retval = {}
+    for k, v in d.items():
+        try:
+            retval[k] = [dtype(i) for i in v.split(delimiter)]
+        except TypeError:
+            print(f'Error happened with path="{path}", id="{k}", value="{v}"')
+            raise
+    return retval
+
+
+def is_chinese(ch):
+    return u'\u4e00' <= ch <= u'\u9fff'
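+
+
+# is_chinese checks the CJK Unified Ideographs block, e.g.:
+#     >>> [ch for ch in 'hi你好ok' if is_chinese(ch)]
+#     ['你', '好']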
+
+
+def get_voc_out(mel):
+    # vocoder
+    # parse_args is expected to be supplied by the calling script and to
+    # provide voc, voc_config, voc_ckpt and voc_stat
+    args = parse_args()
+    with open(args.voc_config) as f:
+        voc_config = CfgNode(yaml.safe_load(f))
+    voc_inference = get_voc_inference(
+        voc=args.voc,
+        voc_config=voc_config,
+        voc_ckpt=args.voc_ckpt,
+        voc_stat=args.voc_stat)
+
+    with paddle.no_grad():
+        wav = voc_inference(mel)
+    return np.squeeze(wav)
+
+
+def eval_durs(phns, target_lang: str='zh', fs: int=24000, n_shift: int=300):
+    if target_lang == 'en':
+        am = "fastspeech2_ljspeech"
+        am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
+        am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
+        am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
+        phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
+    elif target_lang == 'zh':
+        am = "fastspeech2_csmsc"
+        am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
+        am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
+        am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
+        phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
+    else:
+        raise ValueError(
+            f"target_lang must be 'en' or 'zh', but got {target_lang}")
+
+    # Init body.
+    with open(am_config) as f:
+        am_config = CfgNode(yaml.safe_load(f))
+
+    am_inference, am = get_am_inference(
+        am=am,
+        am_config=am_config,
+        am_ckpt=am_ckpt,
+        am_stat=am_stat,
+        phones_dict=phones_dict,
+        return_am=True)
+
+    vocab_phones = {}
+    with open(phones_dict, "r") as f:
+        phn_id = [line.strip().split() for line in f.readlines()]
+    for phn, id_ in phn_id:
+        vocab_phones[phn] = int(id_)
+    vocab_size = len(vocab_phones)
+    # replace phones missing from the dictionary with 'sp'
+    phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
+
+    phone_ids = [vocab_phones[item] for item in phonemes]
+    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
+    _, d_outs, _, _ = am.inference(phone_ids)
+    d_outs = d_outs.tolist()
+    return d_outs
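+
+
+# How the pieces in this PR compose end to end (a sketch; the paths and
+# checkpoints follow the examples above and must exist locally):
+#     >>> from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
+#     >>> outs = get_phns_spans(
+#     ...     wav_path='exp/p243_313.wav',
+#     ...     old_str='For that reason cover should not be given.',
+#     ...     new_str='for that reason cover is impossible to be given.')
+#     >>> durs = eval_durs(outs['new_phns'], target_lang='en')
+# the durations and masked mel then go through the ERNIE-SAT model
+# (synthesize_e2e.decode_with_model) and finally through get_voc_out.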