From b5f376e63b77a009b383c1720451e857b45e8147 Mon Sep 17 00:00:00 2001 From: iftaken Date: Thu, 15 Sep 2022 10:19:06 +0800 Subject: [PATCH] add voice clone function --- demos/speech_web/.gitignore | 1 + .../speech_web/speech_server/requirements.txt | 14 +- demos/speech_web/speech_server/src/align.py | 410 +++++++++++++ .../speech_web/speech_server/src/ernie_sat.py | 247 ++++++++ .../speech_server/src/ernie_sat_tool.py | 542 +++++++++++++++++ .../speech_web/speech_server/src/finetune.py | 142 +++++ .../speech_server/src/ft/check_oov.py | 125 ++++ .../speech_server/src/ft/extract.py | 287 +++++++++ .../speech_server/src/ft/finetune_tool.py | 316 ++++++++++ .../speech_server/src/ft/label_process.py | 63 ++ .../speech_server/src/ft/prepare_env.py | 35 ++ .../speech_server/src/ft/synthesize.py | 262 +++++++++ .../speech_server/src/ge2e_clone.py | 130 +++++ .../speech_server/src/tdnn_clone.py | 113 ++++ demos/speech_web/speech_server/vc.py | 543 ++++++++++++++++++ demos/speech_web/web_client/package.json | 1 + demos/speech_web/web_client/src/api/API.js | 20 + demos/speech_web/web_client/src/api/ApiVC.js | 88 +++ .../src/components/Content/Header/Header.vue | 2 +- .../src/components/Content/Header/style.less | 1 + .../web_client/src/components/Experience.vue | 13 + .../SubMenu/ENIRE_SAT/ENIRE_SAT.vue | 487 ++++++++++++++++ .../components/SubMenu/FineTune/FineTune.vue | 427 ++++++++++++++ .../SubMenu/VoiceClone/VoiceClone.vue | 379 ++++++++++++ demos/speech_web/web_client/src/main.js | 4 + demos/speech_web/web_client/yarn.lock | 5 + paddlespeech/__init__.py | 12 + 27 files changed, 4660 insertions(+), 9 deletions(-) create mode 100644 demos/speech_web/speech_server/src/align.py create mode 100644 demos/speech_web/speech_server/src/ernie_sat.py create mode 100644 demos/speech_web/speech_server/src/ernie_sat_tool.py create mode 100644 demos/speech_web/speech_server/src/finetune.py create mode 100644 demos/speech_web/speech_server/src/ft/check_oov.py create mode 100644 demos/speech_web/speech_server/src/ft/extract.py create mode 100644 demos/speech_web/speech_server/src/ft/finetune_tool.py create mode 100644 demos/speech_web/speech_server/src/ft/label_process.py create mode 100644 demos/speech_web/speech_server/src/ft/prepare_env.py create mode 100644 demos/speech_web/speech_server/src/ft/synthesize.py create mode 100644 demos/speech_web/speech_server/src/ge2e_clone.py create mode 100644 demos/speech_web/speech_server/src/tdnn_clone.py create mode 100644 demos/speech_web/speech_server/vc.py create mode 100644 demos/speech_web/web_client/src/api/ApiVC.js create mode 100644 demos/speech_web/web_client/src/components/SubMenu/ENIRE_SAT/ENIRE_SAT.vue create mode 100644 demos/speech_web/web_client/src/components/SubMenu/FineTune/FineTune.vue create mode 100644 demos/speech_web/web_client/src/components/SubMenu/VoiceClone/VoiceClone.vue diff --git a/demos/speech_web/.gitignore b/demos/speech_web/.gitignore index 54418e605..72f5c52d6 100644 --- a/demos/speech_web/.gitignore +++ b/demos/speech_web/.gitignore @@ -13,4 +13,5 @@ *.pdmodel */source/* */PaddleSpeech/* +*/tmp*/* diff --git a/demos/speech_web/speech_server/requirements.txt b/demos/speech_web/speech_server/requirements.txt index 607f0d4d0..c0880df8b 100644 --- a/demos/speech_web/speech_server/requirements.txt +++ b/demos/speech_web/speech_server/requirements.txt @@ -1,13 +1,11 @@ aiofiles faiss-cpu -fastapi -librosa -numpy -paddlenlp -paddlepaddle -paddlespeech pydantic -python-multipartscikit_learn -SoundFile +python-multipart +scikit_learn 
starlette uvicorn +numpy==1.20.0 +librosa==0.8.1 +praatio==5.0.0 +pyworld==0.3.0 \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/align.py b/demos/speech_web/speech_server/src/align.py new file mode 100644 index 000000000..3c6a172d2 --- /dev/null +++ b/demos/speech_web/speech_server/src/align.py @@ -0,0 +1,410 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import shutil +from pathlib import Path + +import librosa +import numpy as np +import pypinyin +from praatio import textgrid + +from paddlespeech.t2s.exps.ernie_sat.utils import get_dict +from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name + +DICT_EN = 'source/tools/aligner/cmudict-0.7b' +DICT_EN_v2 = 'source/tools/aligner/cmudict-0.7b.dict' +DICT_ZH = 'source/tools/aligner/simple.lexicon' +DICT_ZH_v2 = 'source/tools/aligner/simple.dict' +MODEL_DIR_EN = 'source/tools/aligner/vctk_model.zip' +MODEL_DIR_ZH = 'source/tools/aligner/aishell3_model.zip' +MFA_PATH = 'source/tools/montreal-forced-aligner/bin' +os.environ['PATH'] = os.path.realpath(MFA_PATH) + '/:' + os.environ['PATH'] + + +def _get_max_idx(dic): + return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1] + + +def _readtg(tg_path: str, lang: str='en', fs: int=24000, n_shift: int=300): + alignment = textgrid.openTextgrid(tg_path, includeEmptyIntervals=True) + phones = [] + ends = [] + words = [] + + for interval in alignment.tierDict['words'].entryList: + word = interval.label + if word: + words.append(word) + for interval in alignment.tierDict['phones'].entryList: + phone = interval.label + phones.append(phone) + ends.append(interval.end) + frame_pos = librosa.time_to_frames(ends, sr=fs, hop_length=n_shift) + durations = np.diff(frame_pos, prepend=0) + assert len(durations) == len(phones) + # merge '' and sp in the end + if phones[-1] == '' and len(phones) > 1 and phones[-2] == 'sp': + phones = phones[:-1] + durations[-2] += durations[-1] + durations = durations[:-1] + + # replace ' and 'sil' with 'sp' + phones = ['sp' if (phn == '' or phn == 'sil') else phn for phn in phones] + + if lang == 'en': + DICT = DICT_EN + + elif lang == 'zh': + DICT = DICT_ZH + + word2phns_dict = get_dict(DICT) + + phn2word_dict = [] + for word in words: + if lang == 'en': + word = word.upper() + phn2word_dict.append([word2phns_dict[word].split(), word]) + + non_sp_idx = 0 + word_idx = 0 + i = 0 + word2phns = {} + while i < len(phones): + phn = phones[i] + if phn == 'sp': + word2phns[str(word_idx) + '_sp'] = ['sp'] + i += 1 + else: + phns, word = phn2word_dict[non_sp_idx] + word2phns[str(word_idx) + '_' + word] = phns + non_sp_idx += 1 + i += len(phns) + word_idx += 1 + sum_phn = sum(len(word2phns[k]) for k in word2phns) + assert sum_phn == len(phones) + + results = '' + for (p, d) in zip(phones, durations): + results += p + ' ' + str(d) + ' ' + return results.strip(), word2phns + + +def alignment(wav_path: str, + text: str, + fs: int=24000, + lang='en', + n_shift: int=300, + 
mfa_version='v1'): + wav_name = os.path.basename(wav_path) + utt = wav_name.split('.')[0] + # prepare data for MFA + tmp_name = get_tmp_name(text=text) + tmpbase = './tmp_dir/' + tmp_name + tmpbase = Path(tmpbase) + tmpbase.mkdir(parents=True, exist_ok=True) + print("tmp_name in alignment:", tmp_name) + + shutil.copyfile(wav_path, tmpbase / wav_name) + txt_name = utt + '.txt' + txt_path = tmpbase / txt_name + with open(txt_path, 'w') as wf: + wf.write(text + '\n') + # MFA + if mfa_version == 'v1': + if lang == 'en': + DICT = DICT_EN + MODEL_DIR = MODEL_DIR_EN + + elif lang == 'zh': + DICT = DICT_ZH + MODEL_DIR = MODEL_DIR_ZH + else: + print('please input right lang!!') + exit(-1) + CMD = 'mfa_align' + ' ' + str( + tmpbase) + ' ' + DICT + ' ' + MODEL_DIR + ' ' + str(tmpbase) + os.system(CMD) + tg_path = str(tmpbase) + '/' + tmp_name + '/' + utt + '.TextGrid' + else: + # mfa 2.0 + mfa_out = os.path.join(tmpbase, "mfa_out") + os.makedirs(mfa_out, exist_ok=True) + if lang == 'en': + DICT = DICT_EN_v2 + MODEL_DIR = MODEL_DIR_EN + + elif lang == 'zh': + DICT = DICT_ZH_v2 + MODEL_DIR = MODEL_DIR_ZH + else: + print('please input right lang!!') + exit(-1) + CMD = 'mfa align' + ' ' + str( + tmpbase) + ' ' + DICT + ' ' + MODEL_DIR + ' ' + str(mfa_out) + os.system(CMD) + tg_path = str(mfa_out) + '/' + utt + '.TextGrid' + phn_dur, word2phns = _readtg(tg_path, lang=lang) + phn_dur = phn_dur.split() + phns = phn_dur[::2] + durs = phn_dur[1::2] + durs = [int(d) for d in durs] + assert len(phns) == len(durs) + return phns, durs, word2phns + + +def words2phns(text: str, lang='en'): + ''' + Args: + text (str): + input text. + eg: for that reason cover is impossible to be given. + lang (str): + 'en' or 'zh' + Returns: + List[str]: phones of input text. + eg: + ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0', + 'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1', + 'G', 'IH1', 'V', 'AH0', 'N'] + + Dict(str, str): key - idx_word + value - phones + eg: + {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'], + '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'],'3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'], + '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'], + '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']} + ''' + text = text.strip() + words = [] + for pun in [ + ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',', + u'。', u':', u';', u'!', u'?', u'(', u')' + ]: + text = text.replace(pun, ' ') + for wrd in text.split(): + if (wrd[-1] == '-'): + wrd = wrd[:-1] + if (wrd[0] == "'"): + wrd = wrd[1:] + if wrd: + words.append(wrd) + if lang == 'en': + dictfile = DICT_EN + elif lang == 'zh': + dictfile = DICT_ZH + else: + print('please input right lang!!') + + word2phns_dict = get_dict(dictfile) + ds = word2phns_dict.keys() + phns = [] + wrd2phns = {} + for index, wrd in enumerate(words): + if lang == 'en': + wrd = wrd.upper() + if (wrd not in ds): + wrd2phns[str(index) + '_' + wrd] = 'spn' + phns.extend(['spn']) + else: + wrd2phns[str(index) + '_' + wrd] = word2phns_dict[wrd].split() + phns.extend(word2phns_dict[wrd].split()) + return phns, wrd2phns + + +def get_phns_spans(wav_path: str, + old_str: str='', + new_str: str='', + source_lang: str='en', + target_lang: str='en', + fs: int=24000, + n_shift: int=300, + mfa_version='v1'): + is_append = (old_str == new_str[:len(old_str)]) + old_phns, mfa_start, mfa_end = [], [], [] + # source + lang = source_lang + phn, dur, 
w2p = alignment( + wav_path=wav_path, text=old_str, lang=lang, fs=fs, n_shift=n_shift, mfa_version=mfa_version) + + new_d_cumsum = np.pad(np.array(dur).cumsum(0), (1, 0), 'constant').tolist() + mfa_start = new_d_cumsum[:-1] + mfa_end = new_d_cumsum[1:] + old_phns = phn + + # target + if is_append and (source_lang != target_lang): + cross_lingual_clone = True + else: + cross_lingual_clone = False + + if cross_lingual_clone: + str_origin = new_str[:len(old_str)] + str_append = new_str[len(old_str):] + + if target_lang == 'zh': + phns_origin, origin_w2p = words2phns(str_origin, lang='en') + phns_append, append_w2p_tmp = words2phns(str_append, lang='zh') + elif target_lang == 'en': + # 原始句子 + phns_origin, origin_w2p = words2phns(str_origin, lang='zh') + # clone 句子 + phns_append, append_w2p_tmp = words2phns(str_append, lang='en') + else: + assert target_lang == 'zh' or target_lang == 'en', \ + 'cloning is not support for this language, please check it.' + + new_phns = phns_origin + phns_append + + append_w2p = {} + length = len(origin_w2p) + for key, value in append_w2p_tmp.items(): + idx, wrd = key.split('_') + append_w2p[str(int(idx) + length) + '_' + wrd] = value + new_w2p = origin_w2p.copy() + new_w2p.update(append_w2p) + + else: + if source_lang == target_lang: + new_phns, new_w2p = words2phns(new_str, lang=source_lang) + else: + assert source_lang == target_lang, \ + 'source language is not same with target language...' + + span_to_repl = [0, len(old_phns) - 1] + span_to_add = [0, len(new_phns) - 1] + left_idx = 0 + new_phns_left = [] + sp_count = 0 + # find the left different index + # 因为可能 align 时候的 words2phns 和直接 words2phns, 前者会有 sp? + for key in w2p.keys(): + idx, wrd = key.split('_') + if wrd == 'sp': + sp_count += 1 + new_phns_left.append('sp') + else: + idx = str(int(idx) - sp_count) + if idx + '_' + wrd in new_w2p: + # 是 new_str phn 序列的 index + left_idx += len(new_w2p[idx + '_' + wrd]) + # old phn 序列 + new_phns_left.extend(w2p[key]) + else: + span_to_repl[0] = len(new_phns_left) + span_to_add[0] = len(new_phns_left) + break + + # reverse w2p and new_w2p + right_idx = 0 + new_phns_right = [] + sp_count = 0 + w2p_max_idx = _get_max_idx(w2p) + new_w2p_max_idx = _get_max_idx(new_w2p) + new_phns_mid = [] + if is_append: + new_phns_right = [] + new_phns_mid = new_phns[left_idx:] + span_to_repl[0] = len(new_phns_left) + span_to_add[0] = len(new_phns_left) + span_to_add[1] = len(new_phns_left) + len(new_phns_mid) + span_to_repl[1] = len(old_phns) - len(new_phns_right) + # speech edit + else: + for key in list(w2p.keys())[::-1]: + idx, wrd = key.split('_') + if wrd == 'sp': + sp_count += 1 + new_phns_right = ['sp'] + new_phns_right + else: + idx = str(new_w2p_max_idx - (w2p_max_idx - int(idx) - sp_count)) + if idx + '_' + wrd in new_w2p: + right_idx -= len(new_w2p[idx + '_' + wrd]) + new_phns_right = w2p[key] + new_phns_right + else: + span_to_repl[1] = len(old_phns) - len(new_phns_right) + new_phns_mid = new_phns[left_idx:right_idx] + span_to_add[1] = len(new_phns_left) + len(new_phns_mid) + if len(new_phns_mid) == 0: + span_to_add[1] = min(span_to_add[1] + 1, len(new_phns)) + span_to_add[0] = max(0, span_to_add[0] - 1) + span_to_repl[0] = max(0, span_to_repl[0] - 1) + span_to_repl[1] = min(span_to_repl[1] + 1, + len(old_phns)) + break + new_phns = new_phns_left + new_phns_mid + new_phns_right + ''' + For that reason cover should not be given. + For that reason cover is impossible to be given. 
+ span_to_repl: [17, 23] "should not" + span_to_add: [17, 30] "is impossible to" + ''' + outs = {} + outs['mfa_start'] = mfa_start + outs['mfa_end'] = mfa_end + outs['old_phns'] = old_phns + outs['new_phns'] = new_phns + outs['span_to_repl'] = span_to_repl + outs['span_to_add'] = span_to_add + + return outs + + +if __name__ == '__main__': + text = "For that reason cover should not be given." + phn, dur, word2phns = alignment("source/p243_313.wav", text, lang='en') + print(phn, dur) + print(word2phns) + print("---------------------------------") + # 这里可以用我们的中文前端得到 pinyin 序列 + text_zh = "卡尔普陪外孙玩滑梯。" + text_zh = pypinyin.lazy_pinyin( + text_zh, + neutral_tone_with_five=True, + style=pypinyin.Style.TONE3, + tone_sandhi=True) + text_zh = " ".join(text_zh) + phn, dur, word2phns = alignment("source/000001.wav", text_zh, lang='zh') + print(phn, dur) + print(word2phns) + print("---------------------------------") + phns, wrd2phns = words2phns(text, lang='en') + print("phns:", phns) + print("wrd2phns:", wrd2phns) + print("---------------------------------") + + phns, wrd2phns = words2phns(text_zh, lang='zh') + print("phns:", phns) + print("wrd2phns:", wrd2phns) + print("---------------------------------") + + outs = get_phns_spans( + wav_path="source/p243_313.wav", + old_str="For that reason cover should not be given.", + new_str="for that reason cover is impossible to be given.") + + mfa_start = outs["mfa_start"] + mfa_end = outs["mfa_end"] + old_phns = outs["old_phns"] + new_phns = outs["new_phns"] + span_to_repl = outs["span_to_repl"] + span_to_add = outs["span_to_add"] + print("mfa_start:", mfa_start) + print("mfa_end:", mfa_end) + print("old_phns:", old_phns) + print("new_phns:", new_phns) + print("span_to_repl:", span_to_repl) + print("span_to_add:", span_to_add) + print("---------------------------------") diff --git a/demos/speech_web/speech_server/src/ernie_sat.py b/demos/speech_web/speech_server/src/ernie_sat.py new file mode 100644 index 000000000..0bace1581 --- /dev/null +++ b/demos/speech_web/speech_server/src/ernie_sat.py @@ -0,0 +1,247 @@ +from .ernie_sat_tool import ernie_sat_web +import os + +class SAT: + def __init__(self, mfa_version='v1'): + self.mfa_version = mfa_version + + def zh_synthesize_edit(self, + old_str:str, + new_str:str, + input_name:os.PathLike, + output_name:os.PathLike, + task_name:str="synthesize" + ): + + if task_name not in ['synthesize', 'edit']: + print("task name only in ['edit', 'synthesize']") + return None + + # erniesat model + erniesat_config = "source/model/erniesat_aishell3_ckpt_1.2.0/default.yaml" + erniesat_ckpt = "source/model/erniesat_aishell3_ckpt_1.2.0/snapshot_iter_289500.pdz" + erniesat_stat = "source/model/erniesat_aishell3_ckpt_1.2.0/speech_stats.npy" + phones_dict = "source/model/erniesat_aishell3_ckpt_1.2.0/phone_id_map.txt" + duration_adjust = True + # vocoder + voc = "hifigan_aishell3" + voc_config = "source/model/hifigan_aishell3_ckpt_0.2.0/default.yaml" + voc_ckpt = "source/model/hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz" + voc_stat = "source/model/hifigan_aishell3_ckpt_0.2.0/feats_stats.npy" + + source_lang = "zh" + target_lang = "zh" + wav_path = input_name + output_name = output_name + + output_name = ernie_sat_web(erniesat_config, + old_str, + new_str, + source_lang, + target_lang, + task_name, + erniesat_ckpt, + erniesat_stat, + phones_dict, + voc_config, + voc, + voc_ckpt, + voc_stat, + duration_adjust, + wav_path, + output_name, + mfa_version=self.mfa_version + ) + return output_name + + + def crossclone(self, + 
old_str:str, + new_str:str,input_name:os.PathLike, + output_name:os.PathLike, + source_lang:str, + target_lang:str, + ): + # erniesat model + erniesat_config = "source/model/erniesat_aishell3_vctk_ckpt_1.2.0/default.yaml" + erniesat_ckpt = "source/model/erniesat_aishell3_vctk_ckpt_1.2.0/snapshot_iter_489000.pdz" + erniesat_stat = "source/model/erniesat_aishell3_vctk_ckpt_1.2.0/speech_stats.npy" + phones_dict = "source/model/erniesat_aishell3_vctk_ckpt_1.2.0/phone_id_map.txt" + duration_adjust = True + # vocoder + voc = "hifigan_aishell3" + voc_config = "source/model/hifigan_aishell3_ckpt_0.2.0/default.yaml" + voc_ckpt = "source/model/hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz" + voc_stat = "source/model/hifigan_aishell3_ckpt_0.2.0/feats_stats.npy" + + task_name = 'synthesize' + wav_path = input_name + output_name = output_name + + output_name = ernie_sat_web(erniesat_config, + old_str, + new_str, + source_lang, + target_lang, + task_name, + erniesat_ckpt, + erniesat_stat, + phones_dict, + voc_config, + voc, + voc_ckpt, + voc_stat, + duration_adjust, + wav_path, + output_name, + mfa_version=self.mfa_version + ) + return output_name + + def en_synthesize_edit(self, + old_str:str, + new_str:str,input_name:os.PathLike, + output_name:os.PathLike, + task_name:str="synthesize"): + # erniesat model + erniesat_config = "source/model/erniesat_vctk_ckpt_1.2.0/default.yaml" + erniesat_ckpt = "source/model/erniesat_vctk_ckpt_1.2.0/snapshot_iter_199500.pdz" + erniesat_stat = "source/model/erniesat_vctk_ckpt_1.2.0/speech_stats.npy" + phones_dict = "source/model/erniesat_vctk_ckpt_1.2.0/phone_id_map.txt" + duration_adjust = True + # vocoder + voc = "hifigan_aishell3" + voc_config = "source/model/hifigan_vctk_ckpt_0.2.0/default.yaml" + voc_ckpt = "source/model/hifigan_vctk_ckpt_0.2.0/snapshot_iter_2500000.pdz" + voc_stat = "source/model/hifigan_vctk_ckpt_0.2.0/feats_stats.npy" + + source_lang = "en" + target_lang = "en" + wav_path = input_name + output_name = output_name + + output_name = ernie_sat_web(erniesat_config, + old_str, + new_str, + source_lang, + target_lang, + task_name, + erniesat_ckpt, + erniesat_stat, + phones_dict, + voc_config, + voc, + voc_ckpt, + voc_stat, + duration_adjust, + wav_path, + output_name, + mfa_version=self.mfa_version + ) + return output_name + + + + +if __name__ == '__main__': + + sat = SAT(mfa_version='v2') + # 中文语音克隆 + print("######## 中文语音克隆 #######") + old_str = "请播放歌曲小苹果。" + new_str = "歌曲真好听。" + input_name = "source/wav/SAT/upload/SSB03540307.wav" + output_name = "source/wav/SAT/out/sat_syn.wav" + output_name = os.path.realpath(output_name) + sat.zh_synthesize_edit( + old_str=old_str, + new_str=new_str, + input_name=input_name, + output_name=output_name, + task_name="synthesize" + ) + + # 中文语音编辑 + print("######## 中文语音编辑 #######") + old_str = "今天天气很好" + new_str = "今天心情很好" + input_name = "source/wav/SAT/upload/SSB03540428.wav" + output_name = "source/wav/SAT/out/sat_edit.wav" + output_name = os.path.realpath(output_name) + print(os.path.realpath(output_name)) + sat.zh_synthesize_edit( + old_str=old_str, + new_str=new_str, + input_name=input_name, + output_name=output_name, + task_name="edit" + ) + + # 中文跨语言克隆 + print("######## 中文 跨语言音色克隆 #######") + old_str = "请播放歌曲小苹果。" + new_str = "Thank you very mych! 
what can i do for you" + source_lang='zh' + target_lang='en' + input_name = "source/wav/SAT/upload/SSB03540307.wav" + output_name = "source/wav/SAT/out/sat_cross_zh2en.wav" + output_name = os.path.realpath(output_name) + print(os.path.realpath(output_name)) + sat.crossclone( + old_str=old_str, + new_str=new_str, + input_name=input_name, + output_name=output_name, + source_lang=source_lang, + target_lang=target_lang + ) + + # 英文跨语言克隆 + print("######## 英文 跨语言音色克隆 #######") + old_str = "For that reason cover should not be given." + new_str = "今天天气很好" + source_lang='en' + target_lang='zh' + input_name = "source/wav/SAT/upload/p243_313.wav" + output_name = "source/wav/SAT/out/sat_cross_en2zh.wav" + output_name = os.path.realpath(output_name) + print(os.path.realpath(output_name)) + sat.crossclone( + old_str=old_str, + new_str=new_str, + input_name=input_name, + output_name=output_name, + source_lang=source_lang, + target_lang=target_lang + ) + + # 英文语音克隆 + print("######## 英文音色克隆 #######") + old_str = "For that reason cover should not be given." + new_str = "I love you very much do you love me" + input_name = "source/wav/SAT/upload/p243_313.wav" + output_name = "source/wav/SAT/out/sat_syn_en.wav" + output_name = os.path.realpath(output_name) + sat.en_synthesize_edit( + old_str=old_str, + new_str=new_str, + input_name=input_name, + output_name=output_name, + task_name="synthesize" + ) + + # 英文语音编辑 + print("######## 英文语音编辑 #######") + old_str = "For that reason cover should not be given." + new_str = "For that reason cover is not impossible to be given." + input_name = "source/wav/SAT/upload/p243_313.wav" + output_name = "source/wav/SAT/out/sat_edit_en.wav" + output_name = os.path.realpath(output_name) + sat.en_synthesize_edit( + old_str=old_str, + new_str=new_str, + input_name=input_name, + output_name=output_name, + task_name="edit" + ) + diff --git a/demos/speech_web/speech_server/src/ernie_sat_tool.py b/demos/speech_web/speech_server/src/ernie_sat_tool.py new file mode 100644 index 000000000..e5c469a59 --- /dev/null +++ b/demos/speech_web/speech_server/src/ernie_sat_tool.py @@ -0,0 +1,542 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import os +from pathlib import Path +from typing import List +from unittest import main + +import librosa +import numpy as np +import paddle +import pypinyin +import soundfile as sf +import yaml +from pypinyin_dict.phrase_pinyin_data import large_pinyin +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn +from paddlespeech.t2s.datasets.get_feats import LogMelFBank +# from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans +from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor +from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy +from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name +from paddlespeech.t2s.exps.syn_utils import get_am_inference +from paddlespeech.t2s.exps.syn_utils import get_voc_inference +from paddlespeech.t2s.exps.syn_utils import norm +from paddlespeech.t2s.utils import str2bool +large_pinyin.load() + +from .align import get_phns_spans + +def eval_durs(phns, target_lang: str='zh', fs: int=24000, n_shift: int=300): + + if target_lang == 'en': + am = "fastspeech2_ljspeech" + am_config = "source/model/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml" + am_ckpt = "source/model/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz" + am_stat = "source/model/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy" + phones_dict = "source/model/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt" + + elif target_lang == 'zh': + am = "fastspeech2_csmsc" + am_config = "source/model/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml" + am_ckpt = "source/model/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz" + am_stat = "source/model/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy" + phones_dict = "source/model/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt" + + # Init body. 
+ with open(am_config) as f: + am_config = CfgNode(yaml.safe_load(f)) + + am_inference, am = get_am_inference( + am=am, + am_config=am_config, + am_ckpt=am_ckpt, + am_stat=am_stat, + phones_dict=phones_dict, + return_am=True) + + vocab_phones = {} + with open(phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + for tone, id in phn_id: + vocab_phones[tone] = int(id) + vocab_size = len(vocab_phones) + phonemes = [phn if phn in vocab_phones else "sp" for phn in phns] + + phone_ids = [vocab_phones[item] for item in phonemes] + phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64)) + _, d_outs, _, _ = am.inference(phone_ids) + d_outs = d_outs.tolist() + return d_outs + + + +def _p2id(phonemes: List[str], vocab_phones) -> np.ndarray: + # replace unk phone with sp + phonemes = [phn if phn in vocab_phones else "sp" for phn in phonemes] + phone_ids = [vocab_phones[item] for item in phonemes] + return np.array(phone_ids, np.int64) + + +def prep_feats_with_dur(wav_path: str, + old_str: str='', + new_str: str='', + source_lang: str='en', + target_lang: str='en', + duration_adjust: bool=True, + fs: int=24000, + n_shift: int=300, + mfa_version='v1'): + ''' + Returns: + np.ndarray: new wav, replace the part to be edited in original wav with 0 + List[str]: new phones + List[float]: mfa start of new wav + List[float]: mfa end of new wav + List[int]: masked mel boundary of original wav + List[int]: masked mel boundary of new wav + ''' + wav_org, _ = librosa.load(wav_path, sr=fs) + phns_spans_outs = get_phns_spans( + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + source_lang=source_lang, + target_lang=target_lang, + fs=fs, + n_shift=n_shift, + mfa_version=mfa_version) + + mfa_start = phns_spans_outs['mfa_start'] + mfa_end = phns_spans_outs['mfa_end'] + old_phns = phns_spans_outs['old_phns'] + new_phns = phns_spans_outs['new_phns'] + span_to_repl = phns_spans_outs['span_to_repl'] + span_to_add = phns_spans_outs['span_to_add'] + + # 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替 + if target_lang in {'en', 'zh'}: + old_durs = eval_durs(old_phns, target_lang=source_lang) + else: + assert target_lang in {'en', 'zh'}, \ + "calculate duration_predict is not support for this language..." + + orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)] + + if duration_adjust: + d_factor = get_dur_adj_factor( + orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns) + d_factor = d_factor * 1.25 + else: + d_factor = 1 + + if target_lang in {'en', 'zh'}: + new_durs = eval_durs(new_phns, target_lang=target_lang) + else: + assert target_lang == "zh" or target_lang == "en", \ + "calculate duration_predict is not support for this language..." + + # duration 要是整数 + new_durs_adjusted = [int(np.ceil(d_factor * i)) for i in new_durs] + + new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]]) + old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]]) + dur_offset = new_span_dur_sum - old_span_dur_sum + new_mfa_start = mfa_start[:span_to_repl[0]] + new_mfa_end = mfa_end[:span_to_repl[0]] + + for dur in new_durs_adjusted[span_to_add[0]:span_to_add[1]]: + if len(new_mfa_end) == 0: + new_mfa_start.append(0) + new_mfa_end.append(dur) + else: + new_mfa_start.append(new_mfa_end[-1]) + new_mfa_end.append(new_mfa_end[-1] + dur) + + new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]] + new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]] + + # 3. 
get new wav + # 在原始句子后拼接 + if span_to_repl[0] >= len(mfa_start): + wav_left_idx = len(wav_org) + wav_right_idx = wav_left_idx + # 在原始句子中间替换 + else: + wav_left_idx = int(np.floor(mfa_start[span_to_repl[0]] * n_shift)) + wav_right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * n_shift)) + blank_wav = np.zeros( + (int(np.ceil(new_span_dur_sum * n_shift)), ), dtype=wav_org.dtype) + # 原始音频,需要编辑的部分替换成空音频,空音频的时间由 fs2 的 duration_predictor 决定 + new_wav = np.concatenate( + [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]]) + + # 4. get old and new mel span to be mask + old_span_bdy = get_span_bdy( + mfa_start=mfa_start, mfa_end=mfa_end, span_to_repl=span_to_repl) + + new_span_bdy = get_span_bdy( + mfa_start=new_mfa_start, mfa_end=new_mfa_end, span_to_repl=span_to_add) + + # old_span_bdy, new_span_bdy 是帧级别的范围 + outs = {} + outs['new_wav'] = new_wav + outs['new_phns'] = new_phns + outs['new_mfa_start'] = new_mfa_start + outs['new_mfa_end'] = new_mfa_end + outs['old_span_bdy'] = old_span_bdy + outs['new_span_bdy'] = new_span_bdy + return outs + + +def prep_feats(wav_path: str, + mel_extractor, + vocab_phones, + erniesat_stat, + collate_fn, + old_str: str='', + new_str: str='', + source_lang: str='en', + target_lang: str='en', + duration_adjust: bool=True, + fs: int=24000, + n_shift: int=300, + mfa_version: str='v1' + ): + + with_dur_outs = prep_feats_with_dur( + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + source_lang=source_lang, + target_lang=target_lang, + duration_adjust=duration_adjust, + fs=fs, + n_shift=n_shift, + mfa_version=mfa_version + ) + + wav_name = os.path.basename(wav_path) + utt_id = wav_name.split('.')[0] + + wav = with_dur_outs['new_wav'] + phns = with_dur_outs['new_phns'] + mfa_start = with_dur_outs['new_mfa_start'] + mfa_end = with_dur_outs['new_mfa_end'] + old_span_bdy = with_dur_outs['old_span_bdy'] + new_span_bdy = with_dur_outs['new_span_bdy'] + span_bdy = np.array(new_span_bdy) + + mel = mel_extractor.get_log_mel_fbank(wav) + erniesat_mean, erniesat_std = np.load(erniesat_stat) + normed_mel = norm(mel, erniesat_mean, erniesat_std) + tmp_name = 'ernie_sat/' + get_tmp_name(text=old_str) + tmpbase = './tmp_dir/' + tmp_name + tmpbase = Path(tmpbase) + tmpbase.mkdir(parents=True, exist_ok=True) + + mel_path = tmpbase / 'mel.npy' + np.save(mel_path, normed_mel) + durations = [e - s for e, s in zip(mfa_end, mfa_start)] + text = _p2id(phns, vocab_phones) + + datum = { + "utt_id": utt_id, + "spk_id": 0, + "text": text, + "text_lengths": len(text), + "speech_lengths": len(normed_mel), + "durations": durations, + "speech": np.load(mel_path), + "align_start": mfa_start, + "align_end": mfa_end, + "span_bdy": span_bdy + } + + batch = collate_fn([datum]) + outs = dict() + outs['batch'] = batch + outs['old_span_bdy'] = old_span_bdy + outs['new_span_bdy'] = new_span_bdy + return outs + + +def get_mlm_output(wav_path: str, + erniesat_inference, + mel_extractor, + vocab_phones, + erniesat_stat, + collate_fn, + old_str: str='', + new_str: str='', + source_lang: str='en', + target_lang: str='en', + duration_adjust: bool=True, + fs: int=24000, + n_shift: int=300, + mfa_version: str='v1' ): + + prep_feats_outs = prep_feats( + wav_path=wav_path, + mel_extractor=mel_extractor, + vocab_phones=vocab_phones, + erniesat_stat=erniesat_stat, + collate_fn=collate_fn, + old_str=old_str, + new_str=new_str, + source_lang=source_lang, + target_lang=target_lang, + duration_adjust=duration_adjust, + fs=fs, + n_shift=n_shift, + mfa_version=mfa_version) + + batch = prep_feats_outs['batch'] 
+ new_span_bdy = prep_feats_outs['new_span_bdy'] + old_span_bdy = prep_feats_outs['old_span_bdy'] + + out_mels = erniesat_inference( + speech=batch['speech'], + text=batch['text'], + masked_pos=batch['masked_pos'], + speech_mask=batch['speech_mask'], + text_mask=batch['text_mask'], + speech_seg_pos=batch['speech_seg_pos'], + text_seg_pos=batch['text_seg_pos'], + span_bdy=new_span_bdy) + + # 拼接音频 + output_feat = paddle.concat(x=out_mels, axis=0) + wav_org, _ = librosa.load(wav_path, sr=fs) + outs = dict() + outs['wav_org'] = wav_org + outs['output_feat'] = output_feat + outs['old_span_bdy'] = old_span_bdy + outs['new_span_bdy'] = new_span_bdy + + return outs + + +def get_wav(wav_path: str, + task_name, + voc_inference, + erniesat_inference, + mel_extractor, + vocab_phones, + erniesat_stat, + collate_fn, + source_lang: str='en', + target_lang: str='en', + old_str: str='', + new_str: str='', + duration_adjust: bool=True, + fs: int=24000, + n_shift: int=300, + mfa_version: str='v1'): + + outs = get_mlm_output( + wav_path=wav_path, + erniesat_inference=erniesat_inference, + mel_extractor=mel_extractor, + vocab_phones=vocab_phones, + erniesat_stat=erniesat_stat, + collate_fn=collate_fn, + old_str=old_str, + new_str=new_str, + source_lang=source_lang, + target_lang=target_lang, + duration_adjust=duration_adjust, + fs=fs, + n_shift=n_shift, + mfa_version=mfa_version) + + wav_org = outs['wav_org'] + output_feat = outs['output_feat'] + old_span_bdy = outs['old_span_bdy'] + new_span_bdy = outs['new_span_bdy'] + + masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]] + + with paddle.no_grad(): + alt_wav = voc_inference(masked_feat) + alt_wav = np.squeeze(alt_wav) + + old_time_bdy = [n_shift * x for x in old_span_bdy] + if task_name == 'edit': + wav_replaced = np.concatenate( + [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]]) + else: + wav_replaced = alt_wav + + wav_dict = {"origin": wav_org, "output": wav_replaced} + return wav_dict + + +def ernie_sat_web(erniesat_config, + old_str, + new_str, + source_lang, + target_lang, + task_name, + erniesat_ckpt, + erniesat_stat, + phones_dict, + voc_config, + voc, + voc_ckpt, + voc_stat, + duration_adjust, + wav_path, + output_name, + mfa_version='v1' + ): + with open(erniesat_config) as f: + erniesat_config = CfgNode(yaml.safe_load(f)) + + # convert Chinese characters to pinyin + if source_lang == 'zh': + old_str = pypinyin.lazy_pinyin( + old_str, + neutral_tone_with_five=True, + style=pypinyin.Style.TONE3, + tone_sandhi=True) + old_str = ' '.join(old_str) + if target_lang == 'zh': + new_str = pypinyin.lazy_pinyin( + new_str, + neutral_tone_with_five=True, + style=pypinyin.Style.TONE3, + tone_sandhi=True) + new_str = ' '.join(new_str) + + if task_name == 'edit': + new_str = new_str + elif task_name == 'synthesize': + new_str = old_str + ' ' + new_str + else: + new_str = old_str + ' ' + new_str + print("new_str:", new_str) + + # Extractor + mel_extractor = LogMelFBank( + sr=erniesat_config.fs, + n_fft=erniesat_config.n_fft, + hop_length=erniesat_config.n_shift, + win_length=erniesat_config.win_length, + window=erniesat_config.window, + n_mels=erniesat_config.n_mels, + fmin=erniesat_config.fmin, + fmax=erniesat_config.fmax) + + collate_fn = build_erniesat_collate_fn( + mlm_prob=erniesat_config.mlm_prob, + mean_phn_span=erniesat_config.mean_phn_span, + seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm', + text_masking=False) + + vocab_phones = {} + + with open(phones_dict, 'rt') as f: + phn_id = [line.strip().split() for line in 
f.readlines()] + for phn, id in phn_id: + vocab_phones[phn] = int(id) + + # ernie sat model + erniesat_inference = get_am_inference( + am='erniesat_dataset', + am_config=erniesat_config, + am_ckpt=erniesat_ckpt, + am_stat=erniesat_stat, + phones_dict=phones_dict) + + with open(voc_config) as f: + voc_config = CfgNode(yaml.safe_load(f)) + + # vocoder + voc_inference = get_voc_inference( + voc=voc, + voc_config=voc_config, + voc_ckpt=voc_ckpt, + voc_stat=voc_stat) + + erniesat_stat = erniesat_stat + + wav_dict = get_wav( + wav_path=wav_path, + task_name=task_name, + voc_inference=voc_inference, + erniesat_inference=erniesat_inference, + mel_extractor=mel_extractor, + vocab_phones=vocab_phones, + erniesat_stat=erniesat_stat, + collate_fn=collate_fn, + source_lang=source_lang, + target_lang=target_lang, + old_str=old_str, + new_str=new_str, + duration_adjust=duration_adjust, + fs=erniesat_config.fs, + n_shift=erniesat_config.n_shift, + mfa_version=mfa_version) + + sf.write( + output_name, wav_dict['output'], samplerate=erniesat_config.fs) + return output_name + + +if __name__ == '__main__': + + erniesat_config = "source/model/erniesat_aishell3_ckpt_1.2.0/default.yaml" + erniesat_ckpt = "source/model/erniesat_aishell3_ckpt_1.2.0/snapshot_iter_289500.pdz" + erniesat_stat = "source/model/erniesat_aishell3_ckpt_1.2.0/speech_stats.npy" + phones_dict = "source/model/erniesat_aishell3_ckpt_1.2.0/phone_id_map.txt" + duration_adjust = True + + voc = "hifigan_aishell3" + voc_config = "source/model/hifigan_aishell3_ckpt_0.2.0/default.yaml" + voc_ckpt = "source/model/hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz" + voc_stat = "source/model/hifigan_aishell3_ckpt_0.2.0/feats_stats.npy" + + + old_str = "今天天气很好" + new_str = "今天心情很好" + source_lang = "zh" + target_lang = "zh" + task_name = "edit" + wav_path = "source/wav/SAT/upload/SSB03540428.wav" + output_name = "source/wav/SAT/out/demo_edit.wav" + + mfa_version='v2' + + ernie_sat_web(erniesat_config, + old_str, + new_str, + source_lang, + target_lang, + task_name, + erniesat_ckpt, + erniesat_stat, + phones_dict, + voc_config, + voc, + voc_ckpt, + voc_stat, + duration_adjust, + wav_path, + output_name, + mfa_version=mfa_version + ) + \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/finetune.py b/demos/speech_web/speech_server/src/finetune.py new file mode 100644 index 000000000..1c387151c --- /dev/null +++ b/demos/speech_web/speech_server/src/finetune.py @@ -0,0 +1,142 @@ +# +# GE2E 里面的函数会干扰这边的训练过程,引起错误 +# 单独运行此处的 finetune 微调过程 +# +import argparse +import os +import subprocess +# from src.ft.finetune_tool import finetune_model +# from ft.finetune_tool import finetune_model, synthesize + +def find_max_ckpt(model_path): + max_ckpt = 0 + for filename in os.listdir(model_path): + if filename.endswith('.pdz'): + files = filename[:-4] + a1, a2, it = files.split("_") + if int(it) > max_ckpt: + max_ckpt = int(it) + return max_ckpt + + +class FineTune: + def __init__(self, mfa_version='v1', pretrained_model_dir="source/model/fastspeech2_aishell3_ckpt_1.1.0"): + self.mfa_version = mfa_version + self.pretrained_model_dir = pretrained_model_dir + + def finetune(self, input_dir, exp_dir = 'temp', epoch=10, batch_size=2): + + mfa_dir = os.path.join(exp_dir, 'mfa_result') + dump_dir = os.path.join(exp_dir, 'dump') + output_dir = os.path.join(exp_dir, 'exp') + lang = "zh" + ngpu = 0 + + cmd = f""" + python src/ft/finetune_tool.py --input_dir {input_dir} \ + --pretrained_model_dir {self.pretrained_model_dir} \ + --mfa_dir {mfa_dir} \ + 
--dump_dir {dump_dir} \ + --output_dir {output_dir} \ + --lang {lang} \ + --ngpu {ngpu} \ + --epoch {epoch} \ + --batch_size {batch_size} \ + --mfa_version {self.mfa_version} + """ + + return self.run_cmd(cmd=cmd, output_name=exp_dir) + + + def synthesize(self, text, wav_name, out_wav_dir, exp_dir = 'tmp_dir'): + + # 合成测试 + pretrained_model_dir = self.pretrained_model_dir + print("exp_dir: ", exp_dir) + dump_dir = os.path.join(exp_dir, 'dump') + output_dir = os.path.join(exp_dir, 'exp') + text_path = os.path.join(exp_dir, 'sentences.txt') + lang = "zh" + + model_path = f"{output_dir}/checkpoints" + ckpt = find_max_ckpt(model_path) + + # 生成对应的语句 + with open(text_path, "w", encoding='utf8') as f: + f.write(wav_name+" "+text) + + lang = "zh" + spk_id = 0 + ngpu = 0 + am = "fastspeech2_aishell3" + am_config = f"{pretrained_model_dir}/default.yaml" + am_ckpt = f"{output_dir}/checkpoints/snapshot_iter_{ckpt}.pdz" + am_stat = f"{pretrained_model_dir}/speech_stats.npy" + speaker_dict = f"{dump_dir}/speaker_id_map.txt" + phones_dict = f"{dump_dir}/phone_id_map.txt" + tones_dict = None + voc = "hifigan_aishell3" + voc_config = "source/model/hifigan_aishell3_ckpt_0.2.0/default.yaml" + voc_ckpt = "source/model/hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz" + voc_stat = "source/model/hifigan_aishell3_ckpt_0.2.0/feats_stats.npy" + + cmd = f""" + python src/ft/synthesize.py \ + --am={am} \ + --am_config={am_config} \ + --am_ckpt={am_ckpt} \ + --am_stat={am_stat} \ + --voc={voc} \ + --voc_config={voc_config} \ + --voc_ckpt={voc_ckpt} \ + --voc_stat={voc_stat} \ + --lang={lang} \ + --text={text_path}\ + --output_dir={out_wav_dir} \ + --phones_dict={phones_dict} \ + --speaker_dict={speaker_dict} \ + --ngpu {ngpu} \ + --spk_id={spk_id} + """ + out_wav_path = os.path.join(out_wav_dir, wav_name) + return self.run_cmd(cmd, out_wav_path+'.wav') + + def run_cmd(self, cmd, output_name): + p = subprocess.Popen(cmd, shell=True) + res = p.wait() + print(cmd) + print("运行结果:", res) + if res == 0: + # 运行成功 + print(f"cmd 合成结果: {output_name}") + if os.path.exists(output_name): + return output_name + else: + # 合成的文件不存在 + return None + else: + # 运行失败 + return None + +if __name__ == '__main__': + ft_model = FineTune(mfa_version='v2') + + exp_dir = os.path.realpath("tmp_dir/finetune") + input_dir = os.path.realpath("source/wav/finetune/default") + output_dir = os.path.realpath("source/wav/finetune/out") + + ################################# + ######## 试验轮次验证 ############# + ################################# + lab = 1 + # 先删除数据 + cmd = f"rm -rf {exp_dir}" + os.system(cmd) + ft_model.finetune(input_dir=input_dir, exp_dir = exp_dir, epoch=10, batch_size=2) + + # 合成 + text = "今天的天气真不错" + wav_name = "demo" + str(lab) + "_a" + out_wav_dir = os.path.realpath("source/wav/finetune/out") + ft_model.synthesize(text, wav_name, out_wav_dir, exp_dir = exp_dir) + diff --git a/demos/speech_web/speech_server/src/ft/check_oov.py b/demos/speech_web/speech_server/src/ft/check_oov.py new file mode 100644 index 000000000..4d6854826 --- /dev/null +++ b/demos/speech_web/speech_server/src/ft/check_oov.py @@ -0,0 +1,125 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path +from typing import Dict +from typing import List +from typing import Union + + +def check_phone(label_file: Union[str, Path], + pinyin_phones: Dict[str, str], + mfa_phones: List[str], + am_phones: List[str], + oov_record: str="./oov_info.txt"): + """Check whether the phoneme corresponding to the audio text content + is in the phoneme list of the pretrained mfa model to ensure that the alignment is normal. + Check whether the phoneme corresponding to the audio text content + is in the phoneme list of the pretrained am model to ensure finetune (normalize) is normal. + + Args: + label_file (Union[str, Path]): label file, format: utt_id|phone seq + pinyin_phones (dict): pinyin to phones map dict + mfa_phones (list): the phone list of pretrained mfa model + am_phones (list): the phone list of pretrained mfa model + + Returns: + oov_words (list): oov words + oov_files (list): utt id list that exist oov + oov_file_words (dict): the oov file and oov phone in this file + """ + oov_words = [] + oov_files = [] + oov_file_words = {} + + with open(label_file, "r") as f: + for line in f.readlines(): + utt_id = line.split("|")[0] + transcription = line.strip().split("|")[1] + flag = 0 + temp_oov_words = [] + for word in transcription.split(" "): + if word not in pinyin_phones.keys(): + temp_oov_words.append(word) + flag = 1 + if word not in oov_words: + oov_words.append(word) + else: + for p in pinyin_phones[word]: + if p not in mfa_phones or p not in am_phones: + temp_oov_words.append(word) + flag = 1 + if word not in oov_words: + oov_words.append(word) + if flag == 1: + oov_files.append(utt_id) + oov_file_words[utt_id] = temp_oov_words + + if oov_record is not None: + with open(oov_record, "w") as fw: + fw.write("oov_words: " + str(oov_words) + "\n") + fw.write("oov_files: " + str(oov_files) + "\n") + fw.write("oov_file_words: " + str(oov_file_words) + "\n") + + return oov_words, oov_files, oov_file_words + + +def get_pinyin_phones(lexicon_file: Union[str, Path]): + # pinyin to phones + pinyin_phones = {} + with open(lexicon_file, "r") as f2: + for line in f2.readlines(): + line_list = line.strip().split(" ") + pinyin = line_list[0] + if line_list[1] == '': + phones = line_list[2:] + else: + phones = line_list[1:] + pinyin_phones[pinyin] = phones + + return pinyin_phones + + +def get_mfa_phone(mfa_phone_file: Union[str, Path]): + # get phones from pretrained mfa model (meta.yaml) + mfa_phones = [] + with open(mfa_phone_file, "r") as f: + for line in f.readlines(): + if line.startswith("-"): + phone = line.strip().split(" ")[-1] + mfa_phones.append(phone) + + return mfa_phones + + +def get_am_phone(am_phone_file: Union[str, Path]): + # get phones from pretrained am model (phone_id_map.txt) + am_phones = [] + with open(am_phone_file, "r") as f: + for line in f.readlines(): + phone = line.strip().split(" ")[0] + am_phones.append(phone) + + return am_phones + + +def get_check_result(label_file: Union[str, Path], + lexicon_file: Union[str, Path], + mfa_phone_file: Union[str, Path], + am_phone_file: Union[str, Path]): + pinyin_phones = get_pinyin_phones(lexicon_file) + 
mfa_phones = get_mfa_phone(mfa_phone_file) + am_phones = get_am_phone(am_phone_file) + oov_words, oov_files, oov_file_words = check_phone( + label_file, pinyin_phones, mfa_phones, am_phones) + return oov_words, oov_files, oov_file_words diff --git a/demos/speech_web/speech_server/src/ft/extract.py b/demos/speech_web/speech_server/src/ft/extract.py new file mode 100644 index 000000000..edd92420b --- /dev/null +++ b/demos/speech_web/speech_server/src/ft/extract.py @@ -0,0 +1,287 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import math +import os +from operator import itemgetter +from pathlib import Path +from typing import Dict +from typing import Union + +import jsonlines +import numpy as np +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm + +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.datasets.get_feats import Energy +from paddlespeech.t2s.datasets.get_feats import LogMelFBank +from paddlespeech.t2s.datasets.get_feats import Pitch +from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur +from paddlespeech.t2s.datasets.preprocess_utils import merge_silence +from paddlespeech.t2s.exps.fastspeech2.preprocess import process_sentences + + +def read_stats(stats_file: Union[str, Path]): + scaler = StandardScaler() + scaler.mean_ = np.load(stats_file)[0] + scaler.scale_ = np.load(stats_file)[1] + scaler.n_features_in_ = scaler.mean_.shape[0] + return scaler + + +def get_stats(pretrained_model_dir: Path): + speech_stats_file = pretrained_model_dir / "speech_stats.npy" + pitch_stats_file = pretrained_model_dir / "pitch_stats.npy" + energy_stats_file = pretrained_model_dir / "energy_stats.npy" + speech_scaler = read_stats(speech_stats_file) + pitch_scaler = read_stats(pitch_stats_file) + energy_scaler = read_stats(energy_stats_file) + + return speech_scaler, pitch_scaler, energy_scaler + + +def get_map(duration_file: Union[str, Path], + dump_dir: Path, + pretrained_model_dir: Path): + """get phone map and speaker map, save on dump_dir + + Args: + duration_file (str): durantions.txt + dump_dir (Path): dump dir + pretrained_model_dir (Path): pretrained model dir + """ + # copy phone map file from pretrained model path + phones_dict = dump_dir / "phone_id_map.txt" + os.system("cp %s %s" % + (pretrained_model_dir / "phone_id_map.txt", phones_dict)) + + # create a new speaker map file, replace the previous speakers. 
+ sentences, speaker_set = get_phn_dur(duration_file) + merge_silence(sentences) + speakers = sorted(list(speaker_set)) + num = len(speakers) + speaker_dict = dump_dir / "speaker_id_map.txt" + with open(speaker_dict, 'w') as f, open(pretrained_model_dir / + "speaker_id_map.txt", 'r') as fr: + for i, spk in enumerate(speakers): + f.write(spk + ' ' + str(i) + '\n') + for line in fr.readlines(): + spk_id = line.strip().split(" ")[-1] + if int(spk_id) >= num: + f.write(line) + + vocab_phones = {} + with open(phones_dict, 'rt') as f: + phn_id = [line.strip().split() for line in f.readlines()] + for phn, id in phn_id: + vocab_phones[phn] = int(id) + + vocab_speaker = {} + with open(speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + for spk, id in spk_id: + vocab_speaker[spk] = int(id) + + return sentences, vocab_phones, vocab_speaker + + +def get_extractor(config): + # Extractor + mel_extractor = LogMelFBank( + sr=config.fs, + n_fft=config.n_fft, + hop_length=config.n_shift, + win_length=config.win_length, + window=config.window, + n_mels=config.n_mels, + fmin=config.fmin, + fmax=config.fmax) + pitch_extractor = Pitch( + sr=config.fs, + hop_length=config.n_shift, + f0min=config.f0min, + f0max=config.f0max) + energy_extractor = Energy( + n_fft=config.n_fft, + hop_length=config.n_shift, + win_length=config.win_length, + window=config.window) + + return mel_extractor, pitch_extractor, energy_extractor + + +def normalize(speech_scaler, + pitch_scaler, + energy_scaler, + vocab_phones: Dict, + vocab_speaker: Dict, + raw_dump_dir: Path, + type: str): + + dumpdir = raw_dump_dir / type / "norm" + dumpdir = Path(dumpdir).expanduser() + dumpdir.mkdir(parents=True, exist_ok=True) + + # get dataset + metadata_file = raw_dump_dir / type / "raw" / "metadata.jsonl" + with jsonlines.open(metadata_file, 'r') as reader: + metadata = list(reader) + dataset = DataTable( + metadata, + converters={ + "speech": np.load, + "pitch": np.load, + "energy": np.load, + }) + logging.info(f"The number of files = {len(dataset)}.") + + # process each file + output_metadata = [] + + for item in tqdm(dataset): + utt_id = item['utt_id'] + speech = item['speech'] + pitch = item['pitch'] + energy = item['energy'] + # normalize + speech = speech_scaler.transform(speech) + speech_dir = dumpdir / "data_speech" + speech_dir.mkdir(parents=True, exist_ok=True) + speech_path = speech_dir / f"{utt_id}_speech.npy" + np.save(speech_path, speech.astype(np.float32), allow_pickle=False) + + pitch = pitch_scaler.transform(pitch) + pitch_dir = dumpdir / "data_pitch" + pitch_dir.mkdir(parents=True, exist_ok=True) + pitch_path = pitch_dir / f"{utt_id}_pitch.npy" + np.save(pitch_path, pitch.astype(np.float32), allow_pickle=False) + + energy = energy_scaler.transform(energy) + energy_dir = dumpdir / "data_energy" + energy_dir.mkdir(parents=True, exist_ok=True) + energy_path = energy_dir / f"{utt_id}_energy.npy" + np.save(energy_path, energy.astype(np.float32), allow_pickle=False) + + phone_ids = [vocab_phones[p] for p in item['phones']] + spk_id = vocab_speaker[item["speaker"]] + record = { + "utt_id": item['utt_id'], + "spk_id": spk_id, + "text": phone_ids, + "text_lengths": item['text_lengths'], + "speech_lengths": item['speech_lengths'], + "durations": item['durations'], + "speech": str(speech_path), + "pitch": str(pitch_path), + "energy": str(energy_path) + } + # add spk_emb for voice cloning + if "spk_emb" in item: + record["spk_emb"] = str(item["spk_emb"]) + + output_metadata.append(record) + 
output_metadata.sort(key=itemgetter('utt_id')) + output_metadata_path = Path(dumpdir) / "metadata.jsonl" + with jsonlines.open(output_metadata_path, 'w') as writer: + for item in output_metadata: + writer.write(item) + logging.info(f"metadata dumped into {output_metadata_path}") + + +def extract_feature(duration_file: str, + config, + input_dir: Path, + dump_dir: Path, + pretrained_model_dir: Path): + + sentences, vocab_phones, vocab_speaker = get_map(duration_file, dump_dir, + pretrained_model_dir) + mel_extractor, pitch_extractor, energy_extractor = get_extractor(config) + + wav_files = sorted(list((input_dir).rglob("*.wav"))) + # split data into 3 sections, train: 80%, dev: 10%, test: 10% + num_train = math.ceil(len(wav_files) * 0.8) + num_dev = math.ceil(len(wav_files) * 0.1) + print(num_train, num_dev) + + train_wav_files = wav_files[:num_train] + dev_wav_files = wav_files[num_train:num_train + num_dev] + test_wav_files = wav_files[num_train + num_dev:] + + train_dump_dir = dump_dir / "train" / "raw" + train_dump_dir.mkdir(parents=True, exist_ok=True) + dev_dump_dir = dump_dir / "dev" / "raw" + dev_dump_dir.mkdir(parents=True, exist_ok=True) + test_dump_dir = dump_dir / "test" / "raw" + test_dump_dir.mkdir(parents=True, exist_ok=True) + + # process for the 3 sections + num_cpu = 4 + cut_sil = True + spk_emb_dir = None + write_metadata_method = "w" + speech_scaler, pitch_scaler, energy_scaler = get_stats(pretrained_model_dir) + + if train_wav_files: + process_sentences( + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + nprocs=num_cpu, + cut_sil=cut_sil, + spk_emb_dir=spk_emb_dir, + write_metadata_method=write_metadata_method) + # norm + normalize(speech_scaler, pitch_scaler, energy_scaler, vocab_phones, + vocab_speaker, dump_dir, "train") + + if dev_wav_files: + process_sentences( + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + nprocs=num_cpu, + cut_sil=cut_sil, + spk_emb_dir=spk_emb_dir, + write_metadata_method=write_metadata_method) + # norm + normalize(speech_scaler, pitch_scaler, energy_scaler, vocab_phones, + vocab_speaker, dump_dir, "dev") + + if test_wav_files: + process_sentences( + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + nprocs=num_cpu, + cut_sil=cut_sil, + spk_emb_dir=spk_emb_dir, + write_metadata_method=write_metadata_method) + + # norm + normalize(speech_scaler, pitch_scaler, energy_scaler, vocab_phones, + vocab_speaker, dump_dir, "test") diff --git a/demos/speech_web/speech_server/src/ft/finetune_tool.py b/demos/speech_web/speech_server/src/ft/finetune_tool.py new file mode 100644 index 000000000..9232c9313 --- /dev/null +++ b/demos/speech_web/speech_server/src/ft/finetune_tool.py @@ -0,0 +1,316 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path +from typing import Union +import yaml +from paddle import distributed as dist +from yacs.config import CfgNode +import argparse +from pathlib import Path +import paddle +import soundfile as sf +import yaml +from timer import timer +from yacs.config import CfgNode + +from paddlespeech.t2s.exps.fastspeech2.train import train_sp + +# from .check_oov import get_check_result +# from .extract import extract_feature +# from .label_process import get_single_label +# from .prepare_env import generate_finetune_env + +from check_oov import get_check_result +from extract import extract_feature +from label_process import get_single_label +from prepare_env import generate_finetune_env + +from utils.gen_duration_from_textgrid import gen_duration_from_textgrid + +DICT_EN = 'source/tools/aligner/cmudict-0.7b' +DICT_EN_v2 = 'source/tools/aligner/cmudict-0.7b.dict' +DICT_ZH = 'source/tools/aligner/simple.lexicon' +DICT_ZH_v2 = 'source/tools/aligner/simple.dict' +MODEL_DIR_EN = 'source/tools/aligner/vctk_model.zip' +MODEL_DIR_ZH = 'source/tools/aligner/aishell3_model.zip' +MFA_PHONE_EN = 'source/tools/aligner/vctk_model/meta.yaml' +MFA_PHONE_ZH = 'source/tools/aligner/aishell3_model/meta.yaml' +MFA_PATH = 'source/tools/montreal-forced-aligner/bin' +os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH'] + + +class TrainArgs(): + def __init__(self, ngpu, config_file, dump_dir: Path, output_dir: Path): + self.config = str(config_file) + self.train_metadata = str(dump_dir / "train/norm/metadata.jsonl") + self.dev_metadata = str(dump_dir / "dev/norm/metadata.jsonl") + self.output_dir = str(output_dir) + self.ngpu = ngpu + self.phones_dict = str(dump_dir / "phone_id_map.txt") + self.speaker_dict = str(dump_dir / "speaker_id_map.txt") + self.voice_cloning = False + + +def get_mfa_result( + input_dir: Union[str, Path], + mfa_dir: Union[str, Path], + lang: str='en', + mfa_version='v1'): + """get mfa result + + Args: + input_dir (Union[str, Path]): input dir including wav file and label + mfa_dir (Union[str, Path]): mfa result dir + lang (str, optional): input audio language. Defaults to 'en'. 
+ """ + input_dir = str(input_dir).replace("/newdir", "") + # MFA + if mfa_version == 'v1': + if lang == 'en': + DICT = DICT_EN + MODEL_DIR = MODEL_DIR_EN + + elif lang == 'zh': + DICT = DICT_ZH + MODEL_DIR = MODEL_DIR_ZH + else: + print('please input right lang!!') + + CMD = 'mfa_align' + ' ' + str( + input_dir) + ' ' + DICT + ' ' + MODEL_DIR + ' ' + str(mfa_dir) + os.system(CMD) + else: + if lang == 'en': + DICT = DICT_EN_v2 + MODEL_DIR = MODEL_DIR_EN + + elif lang == 'zh': + DICT = DICT_ZH_v2 + MODEL_DIR = MODEL_DIR_ZH + else: + print('please input right lang!!') + + CMD = 'mfa align' + ' ' + str( + input_dir) + ' ' + DICT + ' ' + MODEL_DIR + ' ' + str(mfa_dir) + os.system(CMD) + + +def finetune_model(input_dir, + pretrained_model_dir, + mfa_dir, + dump_dir, + lang, + output_dir, + ngpu, + epoch, + batch_size, + mfa_version='v1'): + fs = 24000 + n_shift = 300 + input_dir = Path(input_dir).expanduser() + mfa_dir = Path(mfa_dir).expanduser() + mfa_dir.mkdir(parents=True, exist_ok=True) + dump_dir = Path(dump_dir).expanduser() + dump_dir.mkdir(parents=True, exist_ok=True) + output_dir = Path(output_dir).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + pretrained_model_dir = Path(pretrained_model_dir).expanduser() + + # read config + config_file = pretrained_model_dir / "default.yaml" + print("config_path: ") + print(f"########### { config_file } ###########") + with open(config_file) as f: + config = CfgNode(yaml.safe_load(f)) + config.max_epoch = config.max_epoch + epoch + if batch_size > 0: + config.batch_size = batch_size + + if lang == 'en': + lexicon_file = DICT_EN + mfa_phone_file = MFA_PHONE_EN + elif lang == 'zh': + lexicon_file = DICT_ZH + mfa_phone_file = MFA_PHONE_ZH + else: + print('please input right lang!!') + am_phone_file = pretrained_model_dir / "phone_id_map.txt" + label_file = input_dir / "labels.txt" + + #check phone for mfa and am finetune + oov_words, oov_files, oov_file_words = get_check_result( + label_file, lexicon_file, mfa_phone_file, am_phone_file) + input_dir = get_single_label(label_file, oov_files, input_dir) + + # get mfa result + print("input_dir: ", input_dir) + get_mfa_result(input_dir, mfa_dir, lang, mfa_version=mfa_version) + + # # generate durations.txt + duration_file = "./durations.txt" + print("mfa_dir: ", mfa_dir) + gen_duration_from_textgrid(mfa_dir, duration_file, fs, n_shift) + + # generate phone and speaker map files + extract_feature(duration_file, config, input_dir, dump_dir, + pretrained_model_dir) + + # create finetune env + generate_finetune_env(output_dir, pretrained_model_dir) + + # create a new args for training + train_args = TrainArgs(ngpu, config_file, dump_dir, output_dir) + + # finetune models + # dispatch + if ngpu > 1: + dist.spawn(train_sp, (train_args, config), nprocs=ngpu) + else: + train_sp(train_args, config) + return output_dir + +# 合成 + + + +if __name__ == '__main__': + # parse config and args + parser = argparse.ArgumentParser( + description="Preprocess audio and then extract features.") + + parser.add_argument( + "--input_dir", + type=str, + help="directory containing audio and label file") + + parser.add_argument( + "--pretrained_model_dir", + type=str, + help="Path to pretrained model") + + parser.add_argument( + "--mfa_dir", + type=str, + default="./mfa_result", + help="directory to save aligned files") + + parser.add_argument( + "--dump_dir", + type=str, + default="./dump", + help="directory to save feature files and metadata.") + + parser.add_argument( + "--output_dir", + type=str, + 
default="./exp/default/", + help="directory to save finetune model.") + + parser.add_argument( + '--lang', + type=str, + default='zh', + choices=['zh', 'en'], + help='Choose input audio language. zh or en') + + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.") + + parser.add_argument("--epoch", type=int, default=100, help="finetune epoch") + + parser.add_argument( + "--batch_size", + type=int, + default=-1, + help="batch size, default -1 means same as pretrained model") + + parser.add_argument( + "--mfa_version", + type=str, + default='v1', + help="mfa version , you can choose v1 or v2") + + args = parser.parse_args() + + finetune_model(input_dir=args.input_dir, + pretrained_model_dir=args.pretrained_model_dir, + mfa_dir=args.mfa_dir, + dump_dir=args.dump_dir, + lang=args.lang, + output_dir=args.output_dir, + ngpu=args.ngpu, + epoch=args.epoch, + batch_size=args.batch_size, + mfa_version=args.mfa_version) + + + # 10 句话 finetune 测试 + # input_dir = "source/wav/finetune/default" + # pretrained_model_dir = "source/model/fastspeech2_aishell3_ckpt_1.1.0" + # mfa_dir = "tmp_dir/finetune/mfa" + # dump_dir = "tmp_dir/finetune/dump" + # lang = "zh" + # output_dir = "tmp_dir/finetune/out" + # ngpu = 0 + # epoch = 2 + # batch_size = 2 + # mfa_version = 'v2' + # 微调 + # finetune_model(input_dir, + # pretrained_model_dir, + # mfa_dir, + # dump_dir, + # lang, + # output_dir, + # ngpu, + # epoch, + # batch_size, + # mfa_version=mfa_version) + + # # 合成测试 + # text = "source/wav/finetune/test.txt" + + # lang = "zh" + # spk_id = 0 + # am = "fastspeech2_aishell3" + # am_config = f"{pretrained_model_dir}/default.yaml" + # am_ckpt = f"{output_dir}/checkpoints/snapshot_iter_96408.pdz" + # am_stat = f"{pretrained_model_dir}/speech_stats.npy" + # speaker_dict = f"{dump_dir}/speaker_id_map.txt" + # phones_dict = f"{dump_dir}/phone_id_map.txt" + # tones_dict = None + # voc = "hifigan_aishell3" + # voc_config = "source/model/hifigan_aishell3_ckpt_0.2.0/default.yaml" + # voc_ckpt = "source/model/hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz" + # voc_stat = "source/model/hifigan_aishell3_ckpt_0.2.0/feats_stats.npy" + + # wav_output_dir = "source/wav/finetune/out" + + # synthesize(text, + # wav_output_dir, + # lang, + # spk_id, + # am, + # am_config, + # am_ckpt, + # am_stat, + # speaker_dict, + # phones_dict, + # tones_dict, + # voc, + # voc_config, + # voc_ckpt, + # voc_stat + # ) diff --git a/demos/speech_web/speech_server/src/ft/label_process.py b/demos/speech_web/speech_server/src/ft/label_process.py new file mode 100644 index 000000000..711dde4b6 --- /dev/null +++ b/demos/speech_web/speech_server/src/ft/label_process.py @@ -0,0 +1,63 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from pathlib import Path +from typing import List +from typing import Union + + +def change_baker_label(baker_label_file: Union[str, Path], + out_label_file: Union[str, Path]): + """change baker label file to regular label file + + Args: + baker_label_file (Union[str, Path]): Original baker label file + out_label_file (Union[str, Path]): regular label file + """ + with open(baker_label_file) as f: + lines = f.readlines() + + with open(out_label_file, "w") as fw: + for i in range(0, len(lines), 2): + utt_id = lines[i].split()[0] + transcription = lines[i + 1].strip() + fw.write(utt_id + "|" + transcription + "\n") + + +def get_single_label(label_file: Union[str, Path], + oov_files: List[Union[str, Path]], + input_dir: Union[str, Path]): + """Divide the label file into individual files according to label_file + + Args: + label_file (str or Path): label file, format: utt_id|phones id + input_dir (Path): input dir including audios + """ + input_dir = Path(input_dir).expanduser() + new_dir = input_dir / "newdir" + new_dir.mkdir(parents=True, exist_ok=True) + + with open(label_file, "r") as f: + for line in f.readlines(): + utt_id = line.split("|")[0] + if utt_id not in oov_files: + transcription = line.split("|")[1].strip() + wav_file = str(input_dir) + "/" + utt_id + ".wav" + new_wav_file = str(new_dir) + "/" + utt_id + ".wav" + os.system("cp %s %s" % (wav_file, new_wav_file)) + single_file = str(new_dir) + "/" + utt_id + ".txt" + with open(single_file, "w") as fw: + fw.write(transcription) + + return new_dir diff --git a/demos/speech_web/speech_server/src/ft/prepare_env.py b/demos/speech_web/speech_server/src/ft/prepare_env.py new file mode 100644 index 000000000..f2166ff1b --- /dev/null +++ b/demos/speech_web/speech_server/src/ft/prepare_env.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path + + +def generate_finetune_env(output_dir: Path, pretrained_model_dir: Path): + + output_dir = output_dir / "checkpoints/" + output_dir = output_dir.resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + model_path = sorted(list((pretrained_model_dir).rglob("*.pdz")))[0] + model_path = model_path.resolve() + iter = int(str(model_path).split("_")[-1].split(".")[0]) + model_file = str(model_path).split("/")[-1] + + os.system("cp %s %s" % (model_path, output_dir)) + + records_file = output_dir / "records.jsonl" + with open(records_file, "w") as f: + line = "\"time\": \"2022-08-06 07:51:53.463650\", \"path\": \"%s\", \"iteration\": %d" % ( + str(output_dir / model_file), iter) + f.write("{" + line + "}" + "\n") diff --git a/demos/speech_web/speech_server/src/ft/synthesize.py b/demos/speech_web/speech_server/src/ft/synthesize.py new file mode 100644 index 000000000..9ce8286fb --- /dev/null +++ b/demos/speech_web/speech_server/src/ft/synthesize.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from pathlib import Path + +import paddle +import soundfile as sf +import yaml +from timer import timer +from yacs.config import CfgNode + +from paddlespeech.t2s.exps.syn_utils import am_to_static +from paddlespeech.t2s.exps.syn_utils import get_am_inference +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_voc_inference +from paddlespeech.t2s.exps.syn_utils import run_frontend +from paddlespeech.t2s.exps.syn_utils import voc_to_static + + +def evaluate(args): + + # Init body. + with open(args.am_config) as f: + am_config = CfgNode(yaml.safe_load(f)) + with open(args.voc_config) as f: + voc_config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(am_config) + print(voc_config) + + sentences = get_sentences(text_file=args.text, lang=args.lang) + + # frontend + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) + print("frontend done!") + + # acoustic model + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + + am_inference = get_am_inference( + am=args.am, + am_config=am_config, + am_ckpt=args.am_ckpt, + am_stat=args.am_stat, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict, + speaker_dict=args.speaker_dict) + print("acoustic model done!") + # vocoder + voc_inference = get_voc_inference( + voc=args.voc, + voc_config=voc_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) + print("voc done!") + + # whether dygraph to static + if args.inference_dir: + # acoustic model + am_inference = am_to_static( + am_inference=am_inference, + am=args.am, + inference_dir=args.inference_dir, + speaker_dict=args.speaker_dict) + # vocoder + voc_inference = voc_to_static( + voc_inference=voc_inference, + voc=args.voc, + inference_dir=args.inference_dir) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + merge_sentences = False + # Avoid not stopping at the end of a sub sentence when tacotron2_ljspeech dygraph to static graph + # but still not stopping in the end (NOTE by yuantian01 Feb 9 2022) + if am_name == 'tacotron2': + merge_sentences = True + + get_tone_ids = False + if am_name == 'speedyspeech': + get_tone_ids = True + + N = 0 + T = 0 + for utt_id, sentence in sentences: + with timer() as t: + frontend_dict = run_frontend( + frontend=frontend, + text=sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids, + lang=args.lang) + phone_ids = frontend_dict['phone_ids'] + with paddle.no_grad(): + flags = 0 + for i in range(len(phone_ids)): + part_phone_ids = phone_ids[i] + # acoustic model + if am_name == 'fastspeech2': + # multi speaker + if am_dataset in {"aishell3", "vctk", "mix"}: + spk_id = paddle.to_tensor(args.spk_id) + mel = am_inference(part_phone_ids, spk_id) + else: + 
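+                            # single-speaker fastspeech2: no speaker id is needed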
mel = am_inference(part_phone_ids) + elif am_name == 'speedyspeech': + part_tone_ids = frontend_dict['tone_ids'][i] + if am_dataset in {"aishell3", "vctk", "mix"}: + spk_id = paddle.to_tensor(args.spk_id) + mel = am_inference(part_phone_ids, part_tone_ids, + spk_id) + else: + mel = am_inference(part_phone_ids, part_tone_ids) + elif am_name == 'tacotron2': + mel = am_inference(part_phone_ids) + # vocoder + wav = voc_inference(mel) + if flags == 0: + wav_all = wav + flags = 1 + else: + wav_all = paddle.concat([wav_all, wav]) + wav = wav_all.numpy() + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = am_config.fs / speed + print( + f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + sf.write( + str(output_dir / (utt_id + ".wav")), wav, samplerate=am_config.fs) + print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {am_config.fs / (N / T) }") + + +def parse_args(): + # parse args and config + parser = argparse.ArgumentParser( + description="Synthesize with acoustic model & vocoder") + # acoustic model + parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=[ + 'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc', + 'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk', + 'tacotron2_csmsc', 'tacotron2_ljspeech', 'fastspeech2_mix' + ], + help='Choose acoustic model type of tts task.') + parser.add_argument( + '--am_config', type=str, default=None, help='Config of acoustic model.') + parser.add_argument( + '--am_ckpt', + type=str, + default=None, + help='Checkpoint file of acoustic model.') + parser.add_argument( + "--am_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training acoustic model." + ) + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--tones_dict", type=str, default=None, help="tone vocabulary file.") + parser.add_argument( + "--speaker_dict", type=str, default=None, help="speaker id map file.") + parser.add_argument( + '--spk_id', + type=int, + default=0, + help='spk id for multi speaker acoustic model') + # vocoder + parser.add_argument( + '--voc', + type=str, + default='pwgan_csmsc', + choices=[ + 'pwgan_csmsc', + 'pwgan_ljspeech', + 'pwgan_aishell3', + 'pwgan_vctk', + 'mb_melgan_csmsc', + 'style_melgan_csmsc', + 'hifigan_csmsc', + 'hifigan_ljspeech', + 'hifigan_aishell3', + 'hifigan_vctk', + 'wavernn_csmsc', + ], + help='Choose vocoder type of tts task.') + parser.add_argument( + '--voc_config', type=str, default=None, help='Config of voc.') + parser.add_argument( + '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') + parser.add_argument( + "--voc_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training voc." + ) + # other + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. 
zh or en or mix') + + parser.add_argument( + "--inference_dir", + type=str, + default=None, + help="dir to save inference models") + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line.") + parser.add_argument("--output_dir", type=str, help="output dir.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/demos/speech_web/speech_server/src/ge2e_clone.py b/demos/speech_web/speech_server/src/ge2e_clone.py new file mode 100644 index 000000000..166ece8d9 --- /dev/null +++ b/demos/speech_web/speech_server/src/ge2e_clone.py @@ -0,0 +1,130 @@ +""" +G2p Voice Clone +""" + +import argparse +import os +from pathlib import Path +from paddlespeech.t2s.modules.normalizer import ZScore +import numpy as np +import paddle +import soundfile as sf +import yaml +from yacs.config import CfgNode + +from paddlespeech.t2s.frontend.zh_frontend import Frontend +from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor +from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder +from paddlespeech.utils.dynamic_import import dynamic_import + +class VoiceCloneGE2E(): + def __init__(self): + # 设置预训练模型的路径和其他变量 + self.model_alias = { + # acoustic model + "fastspeech2": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2", + "fastspeech2_inference": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", + # voc + "pwgan": + "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", + "pwgan_inference": + "paddlespeech.t2s.models.parallel_wavegan:PWGInference", + } + # am + self.am = "fastspeech2_aishell3" + self.am_config = "source/model/fastspeech2_nosil_aishell3_vc1_ckpt_0.5/default.yaml" + self.am_ckpt = "source/model/fastspeech2_nosil_aishell3_vc1_ckpt_0.5/snapshot_iter_96400.pdz" + self.am_stat = "source/model/fastspeech2_nosil_aishell3_vc1_ckpt_0.5/speech_stats.npy" + self.phones_dict = "source/model/fastspeech2_nosil_aishell3_vc1_ckpt_0.5/phone_id_map.txt" + # voc + self.voc = "pwgan_aishell3" + self.voc_config = "source/model/pwg_aishell3_ckpt_0.5/default.yaml" + self.voc_ckpt = "source/model/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz" + self.voc_stat = "source/model/pwg_aishell3_ckpt_0.5/feats_stats.npy" + # ge2e + self.ge2e_params_path = "source/model/ge2e_ckpt_0.3/step-3000000.pdparams" + with open(self.am_config) as f: + self.am_config = CfgNode(yaml.safe_load(f)) + with open(self.voc_config) as f: + self.voc_config = CfgNode(yaml.safe_load(f)) + + self.p = SpeakerVerificationPreprocessor( + sampling_rate=16000, + audio_norm_target_dBFS=-30, + vad_window_length=30, + vad_moving_average_width=8, + vad_max_silence_length=6, + mel_window_length=25, + mel_window_step=10, + n_mels=40, + partial_n_frames=160, + min_pad_coverage=0.75, + partial_overlap_ratio=0.5 + ) + self.speaker_encoder = LSTMSpeakerEncoder( + n_mels=40, num_layers=3, hidden_size=256, output_size=256 + ) + self.speaker_encoder.set_state_dict(paddle.load(self.ge2e_params_path)) + self.speaker_encoder.eval() + + with open(self.phones_dict, "r") as f: + self.phn_id = [line.strip().split() for line in f.readlines()] + self.vocab_size = len(self.phn_id) + + self.frontend = 
Frontend(phone_vocab_path=self.phones_dict) + + # am + am_name = "fastspeech2" + am_class = dynamic_import(am_name, self.model_alias) + print(self.am_config.n_mels) + self.am = am_class( + idim=self.vocab_size, odim=self.am_config.n_mels, spk_num=None, **self.am_config["model"]) + self.am_inference_class = dynamic_import(am_name + '_inference', self.model_alias) + self.am.set_state_dict(paddle.load(self.am_ckpt)["main_params"]) + self.am.eval() + + am_mu, am_std = np.load(self.am_stat) + am_mu = paddle.to_tensor(am_mu) + am_std = paddle.to_tensor(am_std) + self.am_normalizer = ZScore(am_mu, am_std) + self.am_inference = self.am_inference_class(self.am_normalizer, self.am) + self.am_inference.eval() + + # voc + voc_name = "pwgan" + voc_class = dynamic_import(voc_name, self.model_alias) + voc_inference_class = dynamic_import(voc_name + '_inference', self.model_alias) + self.voc = voc_class(**self.voc_config["generator_params"]) + self.voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"]) + self.voc.remove_weight_norm() + self.voc.eval() + voc_mu, voc_std = np.load(self.voc_stat) + voc_mu = paddle.to_tensor(voc_mu) + voc_std = paddle.to_tensor(voc_std) + voc_normalizer = ZScore(voc_mu, voc_std) + self.voc_inference = voc_inference_class(voc_normalizer, self.voc) + self.voc_inference.eval() + + def vc(self, text, input_wav, out_wav): + + input_ids = self.frontend.get_input_ids(text, merge_sentences=True) + phone_ids = input_ids["phone_ids"][0] + mel_sequences = self.p.extract_mel_partials(self.p.preprocess_wav(input_wav)) + with paddle.no_grad(): + spk_emb = self.speaker_encoder.embed_utterance( + paddle.to_tensor(mel_sequences)) + + with paddle.no_grad(): + wav = self.voc_inference(self.am_inference(phone_ids, spk_emb=spk_emb)) + sf.write(out_wav, wav.numpy(), samplerate=self.am_config.fs) + return True + + +if __name__ == '__main__': + voiceclone = VoiceCloneGE2E() + text = "测试一下你的合成效果" + input_wav = "wav/009901.wav" + out_wav = "wav/9901_clone.wav" + voiceclone.vc(text, input_wav, out_wav) \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/tdnn_clone.py b/demos/speech_web/speech_server/src/tdnn_clone.py new file mode 100644 index 000000000..b75b659d8 --- /dev/null +++ b/demos/speech_web/speech_server/src/tdnn_clone.py @@ -0,0 +1,113 @@ +""" +G2p Voice Clone +""" + +import os +from paddlespeech.t2s.modules.normalizer import ZScore +import numpy as np +import paddle +import soundfile as sf +import yaml +from yacs.config import CfgNode + +from paddlespeech.t2s.frontend.zh_frontend import Frontend +from paddlespeech.utils.dynamic_import import dynamic_import +from paddlespeech.cli.vector import VectorExecutor + + + +model_alias = { + # acoustic model + "fastspeech2": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2", + "fastspeech2_inference": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", + # voc + "pwgan": + "paddlespeech.t2s.models.parallel_wavegan:PWGGenerator", + "pwgan_inference": + "paddlespeech.t2s.models.parallel_wavegan:PWGInference", +} + + +# 设置预训练模型的路径和其他变量 +# am + + +class VoiceCloneTDNN(): + def __init__(self): + + self.am = "fastspeech2_aishell3" + self.am_config = "source/model/fastspeech2_aishell3_ckpt_vc2_1.2.0/default.yaml" + self.am_ckpt = "source/model/fastspeech2_aishell3_ckpt_vc2_1.2.0/snapshot_iter_96400.pdz" + self.am_stat = "source/model/fastspeech2_aishell3_ckpt_vc2_1.2.0/speech_stats.npy" + self.phones_dict = "source/model/fastspeech2_aishell3_ckpt_vc2_1.2.0/phone_id_map.txt" + # voc + self.voc = 
"pwgan_aishell3" + self.voc_config = "source/model/pwg_aishell3_ckpt_0.5/default.yaml" + self.voc_ckpt = "source/model/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz" + self.voc_stat = "source/model/pwg_aishell3_ckpt_0.5/feats_stats.npy" + + with open(self.am_config) as f: + self.am_config = CfgNode(yaml.safe_load(f)) + with open(self.voc_config) as f: + self.voc_config = CfgNode(yaml.safe_load(f)) + self.vec_executor = VectorExecutor() + + + with open(self.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + + self.frontend = Frontend(phone_vocab_path=self.phones_dict) + + # am + am_name = "fastspeech2" + am_class = dynamic_import(am_name, model_alias) + print(self.am_config.n_mels) + self.am = am_class( + idim=vocab_size, odim=self.am_config.n_mels, spk_num=None, **self.am_config["model"]) + self.am_inference_class = dynamic_import(am_name + '_inference', model_alias) + self.am.set_state_dict(paddle.load(self.am_ckpt)["main_params"]) + self.am.eval() + + am_mu, am_std = np.load(self.am_stat) + am_mu = paddle.to_tensor(am_mu) + am_std = paddle.to_tensor(am_std) + self.am_normalizer = ZScore(am_mu, am_std) + self.am_inference = self.am_inference_class(self.am_normalizer, self.am) + self.am_inference.eval() + + # voc + voc_name = "pwgan" + voc_class = dynamic_import(voc_name, model_alias) + voc_inference_class = dynamic_import(voc_name + '_inference', model_alias) + self.voc = voc_class(**self.voc_config["generator_params"]) + self.voc.set_state_dict(paddle.load(self.voc_ckpt)["generator_params"]) + self.voc.remove_weight_norm() + self.voc.eval() + voc_mu, voc_std = np.load(self.voc_stat) + voc_mu = paddle.to_tensor(voc_mu) + voc_std = paddle.to_tensor(voc_std) + voc_normalizer = ZScore(voc_mu, voc_std) + self.voc_inference = voc_inference_class(voc_normalizer, self.voc) + self.voc_inference.eval() + + def vc(self, text, input_wav, out_wav): + input_ids = self.frontend.get_input_ids(text, merge_sentences=True) + phone_ids = input_ids["phone_ids"][0] + spk_emb = self.vec_executor(audio_file=input_wav, force_yes=True) + spk_emb = paddle.to_tensor(spk_emb) + + with paddle.no_grad(): + wav = self.voc_inference(self.am_inference(phone_ids, spk_emb=spk_emb)) + sf.write(out_wav, wav.numpy(), samplerate=self.am_config.fs) + return True + + +if __name__ == '__main__': + voiceclone =VoiceCloneTDNN() + text = "测试一下你的合成效果" + input_wav = os.path.realpath("source/wav/test/009901.wav") + out_wav = os.path.realpath("source/wav/test/9901_clone.wav") + voiceclone.vc(text, input_wav, out_wav) \ No newline at end of file diff --git a/demos/speech_web/speech_server/vc.py b/demos/speech_web/speech_server/vc.py new file mode 100644 index 000000000..ff401fb4d --- /dev/null +++ b/demos/speech_web/speech_server/vc.py @@ -0,0 +1,543 @@ +# todo: +# 1. 开启服务 +# 2. 接收录音音频,返回识别结果 +# 3. 接收ASR识别结果,返回NLP对话结果 +# 4. 
接收NLP对话结果,返回TTS音频 + +import base64 +import yaml +import os +import json +import datetime +import librosa +import soundfile as sf +import numpy as np +import argparse +import uvicorn +import aiofiles +from typing import Optional, List +from pydantic import BaseModel +from fastapi import FastAPI, Header, File, UploadFile, Form, Cookie, WebSocket, WebSocketDisconnect +from fastapi.responses import StreamingResponse +from starlette.responses import FileResponse +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.websockets import WebSocketState as WebSocketState + +from src.util import * +from src.ge2e_clone import VoiceCloneGE2E +from src.tdnn_clone import VoiceCloneTDNN +from src.ernie_sat import SAT +from src.finetune import FineTune + +from paddlespeech.server.engine.asr.online.python.asr_engine import PaddleASRConnectionHanddler +from paddlespeech.server.utils.audio_process import float2pcm + + +# 解析配置 +parser = argparse.ArgumentParser( + prog='PaddleSpeechDemo', add_help=True) + +parser.add_argument( + "--port", + action="store", + type=int, + help="port of the app", + default=8010, + required=False) + +args = parser.parse_args() +port = args.port + +# 这里会对finetune产生影响,所以finetune使用了cmd +vc_model = VoiceCloneGE2E() +vc_model_tdnn = VoiceCloneTDNN() + +# if you use mfa v1 +# sat_model = SAT(mfa_version='v1') +# ft_model = FineTune(mfa_version='v1') +sat_model = SAT(mfa_version='v2') +ft_model = FineTune(mfa_version='v2') + +# 配置文件 +tts_config = "conf/tts_online_application.yaml" +asr_config = "conf/ws_conformer_wenetspeech_application_faster.yaml" +asr_init_path = "source/demo/demo.wav" +db_path = "source/db/vc.sqlite" +ie_model_path = "source/model" + +# 路径配置 +VC_UPLOAD_PATH = "source/wav/vc/upload" +VC_OUT_PATH = "source/wav/vc/out" + +FT_UPLOAD_PATH = "source/wav/finetune/upload" +FT_OUT_PATH = "source/wav/finetune/out" +FT_LABEL_PATH = "source/wav/finetune/label.json" +FT_LABEL_TXT_PATH = "source/wav/finetune/labels.txt" +FT_DEFAULT_PATH = "source/wav/finetune/default" +FT_EXP_BASE_PATH = "tmp_dir/finetune" + +SAT_UPLOAD_PATH = "source/wav/SAT/upload" +SAT_OUT_PATH = "source/wav/SAT/out" +SAT_LABEL_PATH = "source/wav/SAT/label.json" + + + +# SAT 标注结果初始化 +if os.path.exists(SAT_LABEL_PATH): + with open(SAT_LABEL_PATH, "r", encoding='utf8') as f: + sat_label_dic = json.load(f) +else: + sat_label_dic = {} + +# ft 标注结果初始化 +if os.path.exists(FT_LABEL_PATH): + with open(FT_LABEL_PATH, "r", encoding='utf8') as f: + ft_label_dic = json.load(f) +else: + ft_label_dic = {} + + +# 新建文件夹 +base_sources = [ + VC_UPLOAD_PATH, VC_OUT_PATH, + FT_UPLOAD_PATH, FT_OUT_PATH, FT_DEFAULT_PATH, + SAT_UPLOAD_PATH, SAT_OUT_PATH, +] +for path in base_sources: + os.makedirs(path, exist_ok=True) + +######## 测试一下 finetune ############### +# ft_model = FineTune(mfa_version='v2') +# data_path = FT_DEFAULT_PATH +# exp_dir = os.path.join(FT_EXP_BASE_PATH) +# ft_model.finetune(input_dir=os.path.realpath(data_path), exp_dir=os.path.realpath(exp_dir)) + + + +##################################################################### +########################### APP初始化 ############################### +##################################################################### +app = FastAPI() + +###################################################################### +########################### 接口类型 ################################# +##################################################################### + +# 接口结构 +class VcBase(BaseModel): + wavName: str + wavPath: str + +class 
VcBaseText(BaseModel): + wavName: str + wavPath: str + text: str + func: str + +class VcBaseSAT(BaseModel): + old_str : str + new_str : str + language : str + function: str + wav: str # base64编码 + filename: str + +class FTPath(BaseModel): + dataPath: str + +class VcBaseFT(BaseModel): + wav: str # base64编码 + filename: str + wav_path: str + +class VcBaseFTModel(BaseModel): + wav_path: str + +class VcBaseFTSyn(BaseModel): + exp_path: str + text: str + +###################################################################### +########################### 文件列表查询与保存服务 ################################# +##################################################################### + +def getVCList(path): + VC_FileDict = [] + # 查询upload路径下的wav文件名 + for root, dirs, files in os.walk(path, topdown=False): + for name in files: + # print(os.path.join(root, name)) + VC_FileDict.append( + { + 'name': name, + 'path': os.path.join(root, name) + } + ) + VC_FileDict = sorted(VC_FileDict, key=lambda x:x['name'], reverse=True) + return VC_FileDict + +async def saveFiles(files, SavePath): + right = 0 + error = 0 + error_info = "错误文件:" + for file in files: + try: + if 'blob' in file.filename: + out_file_path = os.path.join(SavePath, datetime.datetime.strftime(datetime.datetime.now(), '%H%M') + randName(3) + ".wav") + else: + out_file_path = os.path.join(SavePath, file.filename) + + print("上传文件名:", out_file_path) + async with aiofiles.open(out_file_path, 'wb') as out_file: + content = await file.read() # async read + await out_file.write(content) # async write + # 将文件转成24k, 16bit类型的wav文件 + wav, sr = librosa.load(out_file_path, sr=16000) + sf.write(out_file_path, data=wav, samplerate=sr) + right += 1 + except Exception as e: + error += 1 + error_info = error_info + file.filename + " " + str(e) + "\n" + continue + return f"上传成功:{right}, 上传失败:{error}, 失败原因: {error_info}" + +# 音频下载 +@app.post("/vc/download") +async def VcDownload(base:VcBase): + if os.path.exists(base.wavPath): + return FileResponse(base.wavPath) + else: + return ErrorRequest(message="下载请求失败,文件不存在") + +# 音频下载base64 +@app.post("/vc/download_base64") +async def VcDownloadBase64(base:VcBase): + if os.path.exists(base.wavPath): + # 将文件转成16k, 16bit类型的wav文件 + wav, sr = librosa.load(base.wavPath, sr=16000) + wav = float2pcm(wav) # float32 to int16 + wav_bytes = wav.tobytes() # to bytes + wav_base64 = base64.b64encode(wav_bytes).decode('utf8') + return SuccessRequest(result=wav_base64) + else: + return ErrorRequest(message="播放请求失败,文件不存在") + +###################################################################### +########################### VC 服务 ################################# +##################################################################### + +# 上传文件 +@app.post("/vc/upload") +async def VcUpload(files: List[UploadFile]): + # res = saveFiles(files, VC_UPLOAD_PATH) + right = 0 + error = 0 + error_info = "错误文件:" + for file in files: + try: + if 'blob' in file.filename: + out_file_path = os.path.join(VC_UPLOAD_PATH, datetime.datetime.strftime(datetime.datetime.now(), '%H%M') + randName(3) + ".wav") + else: + out_file_path = os.path.join(VC_UPLOAD_PATH, file.filename) + + print("上传文件名:", out_file_path) + async with aiofiles.open(out_file_path, 'wb') as out_file: + content = await file.read() # async read + await out_file.write(content) # async write + # 将文件转成24k, 16bit类型的wav文件 + wav, sr = librosa.load(out_file_path, sr=16000) + sf.write(out_file_path, data=wav, samplerate=sr) + right += 1 + except Exception as e: + error += 1 + error_info = error_info + file.filename 
+ " " + str(e) + "\n" + continue + return SuccessRequest(result=f"上传成功:{right}, 上传失败:{error}, 失败原因: {error_info}") + + + +# 获取文件列表 +@app.get("/vc/list") +async def VcList(): + res = getVCList(VC_UPLOAD_PATH) + return SuccessRequest(result=res) + +# 获取音频文件 +@app.post("/vc/file") +async def VcFileGet(base:VcBase): + if os.path.exists(base.wavPath): + return FileResponse(base.wavPath) + else: + return ErrorRequest(result="获取文件失败") + +# 删除音频文件 +@app.post("/vc/del") +async def VcFileDel(base:VcBase): + if os.path.exists(base.wavPath): + os.remove(base.wavPath) + return SuccessRequest(result="删除成功") + else: + return ErrorRequest(result="删除失败") + +# 声音克隆G2P +@app.post("/vc/clone_g2p") +async def VcCloneG2P(base:VcBaseText): + if os.path.exists(base.wavPath): + try: + if base.func == 'ge2e': + wavName = base.wavName[:-4]+"_g2p.wav" + wavPath = os.path.join(VC_OUT_PATH, wavName) + vc_model.vc(text=base.text, input_wav=base.wavPath, out_wav=wavPath) + else: + wavName = base.wavName[:-4]+"_tdnn.wav" + wavPath = os.path.join(VC_OUT_PATH, wavName) + vc_model_tdnn.vc(text=base.text, input_wav=base.wavPath, out_wav=wavPath) + res = { + "wavName": wavName, + "wavPath": wavPath + } + return SuccessRequest(result=res) + except Exception as e: + print(e) + return ErrorRequest(message=f"克隆失败,合成过程报错: {str(e)}") + else: + return ErrorRequest(message="克隆失败,音频不存在") + + +###################################################################### +########################### SAT 服务 ################################# +##################################################################### +# 声音克隆SAT +@app.post("/vc/clone_sat") +async def VcCloneSAT(base:VcBaseSAT): + # 重新整理 sat_label_dict + if base.filename not in sat_label_dic or sat_label_dic[base.filename] != base.old_str: + sat_label_dic[base.filename] = base.old_str + with open(SAT_LABEL_PATH, "w", encoding='utf8') as f: + json.dump(sat_label_dic, f, ensure_ascii=False, indent=4) + + input_file_path = base.wav + + # 选择任务 + if base.language == "zh": + # 中文 + if base.function == "synthesize": + output_file_path = os.path.join(SAT_OUT_PATH, "sat_syn_zh_" + base.filename) + # 中文克隆 + sat_result = sat_model.zh_synthesize_edit( + old_str=base.old_str, + new_str=base.new_str, + input_name=os.path.realpath(input_file_path), + output_name=os.path.realpath(output_file_path), + task_name="synthesize" + ) + elif base.function == "edit": + output_file_path = os.path.join(SAT_OUT_PATH, "sat_edit_zh_" + base.filename) + # 中文语音编辑 + sat_result = sat_model.zh_synthesize_edit( + old_str=base.old_str, + new_str=base.new_str, + input_name=os.path.realpath(input_file_path), + output_name=os.path.realpath(output_file_path), + task_name="edit" + ) + elif base.function == "crossclone": + output_file_path = os.path.join(SAT_OUT_PATH, "sat_cross_zh_" + base.filename) + # 中文跨语言 + sat_result = sat_model.crossclone( + old_str=base.old_str, + new_str=base.new_str, + input_name=os.path.realpath(input_file_path), + output_name=os.path.realpath(output_file_path), + source_lang="zh", + target_lang="en" + ) + else: + return ErrorRequest(message="请检查功能选项是否正确,仅支持:synthesize, edit, crossclone") + elif base.language == "en": + if base.function == "synthesize": + output_file_path = os.path.join(SAT_OUT_PATH, "sat_syn_zh_" + base.filename) + # 英文语音克隆 + sat_result = sat_model.en_synthesize_edit( + old_str=base.old_str, + new_str=base.new_str, + input_name=os.path.realpath(input_file_path), + output_name=os.path.realpath(output_file_path), + task_name="synthesize" + ) + elif base.function == "edit": + 
output_file_path = os.path.join(SAT_OUT_PATH, "sat_edit_zh_" + base.filename) + # 英文语音编辑 + sat_result = sat_model.en_synthesize_edit( + old_str=base.old_str, + new_str=base.new_str, + input_name=os.path.realpath(input_file_path), + output_name=os.path.realpath(output_file_path), + task_name="edit" + ) + elif base.function == "crossclone": + output_file_path = os.path.join(SAT_OUT_PATH, "sat_cross_zh_" + base.filename) + # 英文跨语言 + sat_result = sat_model.crossclone( + old_str=base.old_str, + new_str=base.new_str, + input_name=os.path.realpath(input_file_path), + output_name=os.path.realpath(output_file_path), + source_lang="en", + target_lang="zh" + ) + else: + return ErrorRequest(message="请检查功能选项是否正确,仅支持:synthesize, edit, crossclone") + else: + return ErrorRequest(message="请检查功能选项是否正确,仅支持中文和英文") + + if sat_result: + return SuccessRequest(result=sat_result, message="SAT合成成功") + else: + return ErrorRequest(message="SAT 合成失败,请从后台检查错误信息!") + +# SAT 文件列表 +@app.get("/sat/list") +async def SatList(): + res = [] + filelist = getVCList(SAT_UPLOAD_PATH) + for fileitem in filelist: + if fileitem['name'] in sat_label_dic: + fileitem['label'] = sat_label_dic[fileitem['name']] + else: + fileitem['label'] = "" + res.append(fileitem) + return SuccessRequest(result=res) + +# 上传 SAT 音频 +# 上传文件 +@app.post("/sat/upload") +async def SATUpload(files: List[UploadFile]): + right = 0 + error = 0 + error_info = "错误文件:" + for file in files: + try: + if 'blob' in file.filename: + out_file_path = os.path.join(SAT_UPLOAD_PATH, datetime.datetime.strftime(datetime.datetime.now(), '%H%M') + randName(3) + ".wav") + else: + out_file_path = os.path.join(SAT_UPLOAD_PATH, file.filename) + + print("上传文件名:", out_file_path) + async with aiofiles.open(out_file_path, 'wb') as out_file: + content = await file.read() # async read + await out_file.write(content) # async write + # 将文件转成24k, 16bit类型的wav文件 + wav, sr = librosa.load(out_file_path, sr=16000) + sf.write(out_file_path, data=wav, samplerate=sr) + right += 1 + except Exception as e: + error += 1 + error_info = error_info + file.filename + " " + str(e) + "\n" + continue + return SuccessRequest(result=f"上传成功:{right}, 上传失败:{error}, 失败原因: {error_info}") + + +###################################################################### +########################### FinueTune 服务 ################################# +##################################################################### + +# finetune 文件列表 +@app.post("/finetune/list") +async def FineTuneList(Path:FTPath): + dataPath = Path.dataPath + if dataPath == "default": + # 默认路径 + FT_PATH = FT_DEFAULT_PATH + else: + FT_PATH = dataPath + + res = [] + filelist = getVCList(FT_PATH) + for name, value in ft_label_dic.items(): + wav_path = os.path.join(FT_PATH, name) + if not os.path.exists(wav_path): + wav_path = "" + d = { + 'text': value['text'], + 'name': name, + 'path': wav_path + } + res.append(d) + return SuccessRequest(result=res) + +# 一键重置,获取新的文件地址 +@app.get('/finetune/newdir') +async def FTGetNewDir(): + new_path = os.path.join(FT_UPLOAD_PATH, randName(3)) + if not os.path.exists(new_path): + os.makedirs(new_path, exist_ok=True) + # 把 labels.txt 复制进去 + cmd = f"cp {FT_LABEL_TXT_PATH} {new_path}" + os.system(cmd) + return SuccessRequest(result=new_path) + + +# finetune 上传文件 +@app.post("/finetune/upload") +async def FTUpload(base:VcBaseFT): + try: + # 文件夹是否存在 + if not os.path.exists(base.wav_path): + os.makedirs(base.wav_path) + # 保存音频文件 + out_file_path = os.path.join(base.wav_path, base.filename) + wav_b = base64.b64decode(base.wav) + async 
with aiofiles.open(out_file_path, 'wb') as out_file: + await out_file.write(wav_b) # async write + + return SuccessRequest(result=f"上传成功") + except Exception as e: + return ErrorRequest(result=f"上传失败") + +# finetune 微调 +@app.post("/finetune/clone_finetune") +async def FTModel(base:VcBaseFTModel): + # 先检查 wav_path 是否有效 + if base.wav_path == 'default': + data_path = FT_DEFAULT_PATH + else: + data_path = base.wav_path + if not os.path.exists(data_path): + return ErrorRequest(message=f"数据文件夹不存在") + + data_base = data_path.split(os.sep)[-1] + exp_dir = os.path.join(FT_EXP_BASE_PATH, data_base) + try: + exp_dir = ft_model.finetune(input_dir=os.path.realpath(data_path), exp_dir=os.path.realpath(exp_dir)) + if exp_dir: + return SuccessRequest(result=exp_dir) + else: + return ErrorRequest(message=f"微调失败") + except Exception as e: + print(e) + return ErrorRequest(message=f"微调失败") + + +# finetune 合成 +@app.post("/finetune/clone_finetune_syn") +async def FTSyn(base:VcBaseFTSyn): + try: + if not os.path.exists(base.exp_path): + return ErrorRequest(result=f"模型路径不存在") + wav_name = randName(5) + wav_path = ft_model.synthesize(text=base.text, wav_name=wav_name, out_wav_dir=os.path.realpath(FT_OUT_PATH), exp_dir = os.path.realpath(base.exp_path)) + if wav_path: + res = { + "wavName": wav_name+".wav", + "wavPath": wav_path + } + return SuccessRequest(result=res) + else: + return ErrorRequest(message="音频合成失败") + except Exception as e: + return ErrorRequest(message="音频合成失败") + +if __name__ == '__main__': + uvicorn.run(app=app, host='0.0.0.0', port=port) diff --git a/demos/speech_web/web_client/package.json b/demos/speech_web/web_client/package.json index 7f28d4c97..e3701608b 100644 --- a/demos/speech_web/web_client/package.json +++ b/demos/speech_web/web_client/package.json @@ -8,6 +8,7 @@ "preview": "vite preview" }, "dependencies": { + "@element-plus/icons-vue": "^2.0.9", "ant-design-vue": "^2.2.8", "axios": "^0.26.1", "element-plus": "^2.1.9", diff --git a/demos/speech_web/web_client/src/api/API.js b/demos/speech_web/web_client/src/api/API.js index 0feaa63f1..5adca3622 100644 --- a/demos/speech_web/web_client/src/api/API.js +++ b/demos/speech_web/web_client/src/api/API.js @@ -19,6 +19,26 @@ export const apiURL = { CHAT_SOCKET_RECORD: 'ws://localhost:8010/ws/asr/offlineStream', // ChatBot websocket 接口 ASR_SOCKET_RECORD: 'ws://localhost:8010/ws/asr/onlineStream', // Stream ASR 接口 TTS_SOCKET_RECORD: 'ws://localhost:8010/ws/tts/online', // Stream TTS 接口 + + // voice clone + // Voice Clone + VC_List: '/api/vc/list', + SAT_List: '/api/sat/list', + FineTune_List: '/api/finetune/list', + + VC_Upload: '/api/vc/upload', + SAT_Upload: '/api/sat/upload', + FineTune_Upload: '/api/finetune/upload', + FineTune_NewDir: '/api/finetune/newdir', + + VC_Download: '/api/vc/download', + VC_Download_Base64: '/api/vc/download_base64', + VC_Del: '/api/vc/del', + + VC_CloneG2p: '/api/vc/clone_g2p', + VC_CloneSAT: '/api/vc/clone_sat', + VC_CloneFineTune: '/api/finetune/clone_finetune', + VC_CloneFineTuneSyn: '/api/finetune/clone_finetune_syn', } diff --git a/demos/speech_web/web_client/src/api/ApiVC.js b/demos/speech_web/web_client/src/api/ApiVC.js new file mode 100644 index 000000000..0dc0f6834 --- /dev/null +++ b/demos/speech_web/web_client/src/api/ApiVC.js @@ -0,0 +1,88 @@ +import axios from 'axios' +import {apiURL} from "./API.js" + +// 上传音频-vc +export async function vcUpload(params){ + const result = await axios.post(apiURL.VC_Upload, params); + return result +} + +// 上传音频-sat +export async function satUpload(params){ + const 
result = await axios.post(apiURL.SAT_Upload, params); + return result +} + +// 上传音频-finetune +export async function fineTuneUpload(params){ + const result = await axios.post(apiURL.FineTune_Upload, params); + return result +} + +// 删除音频 +export async function vcDel(params){ + const result = await axios.post(apiURL.VC_Del, params); + return result +} + +// 获取音频列表vc +export async function vcList(){ + const result = await axios.get(apiURL.VC_List); + return result +} +// 获取音频列表Sat +export async function satList(){ + const result = await axios.get(apiURL.SAT_List); + return result +} + +// 获取音频列表fineTune +export async function fineTuneList(params){ + const result = await axios.post(apiURL.FineTune_List, params); + return result +} + +// fineTune 一键重置 获取新的文件夹 +export async function fineTuneNewDir(){ + const result = await axios.get(apiURL.FineTune_NewDir); + return result +} + +// 获取音频数据 +export async function vcDownload(params){ + const result = await axios.post(apiURL.VC_Download, params); + return result +} + +// 获取音频数据Base64 +export async function vcDownloadBase64(params){ + const result = await axios.post(apiURL.VC_Download_Base64, params); + return result +} + + +// 克隆合成G2P +export async function vcCloneG2P(params){ + const result = await axios.post(apiURL.VC_CloneG2p, params); + return result +} + +// 克隆合成SAT +export async function vcCloneSAT(params){ + const result = await axios.post(apiURL.VC_CloneSAT, params); + return result +} + +// 克隆合成 - finetune 微调 +export async function vcCloneFineTune(params){ + const result = await axios.post(apiURL.VC_CloneFineTune, params); + return result +} + +// 克隆合成 - finetune 合成 +export async function vcCloneFineTuneSyn(params){ + const result = await axios.post(apiURL.VC_CloneFineTuneSyn, params); + return result +} + + diff --git a/demos/speech_web/web_client/src/components/Content/Header/Header.vue b/demos/speech_web/web_client/src/components/Content/Header/Header.vue index 8135a2bff..c20f3366e 100644 --- a/demos/speech_web/web_client/src/components/Content/Header/Header.vue +++ b/demos/speech_web/web_client/src/components/Content/Header/Header.vue @@ -4,7 +4,7 @@ 飞桨-PaddleSpeech
- PaddleSpeech 是基于飞桨 PaddlePaddle 的语音方向的开源模型库,用于语音和音频中的各种关键任务的开发,欢迎大家Star收藏鼓励 + PaddleSpeech 是基于飞桨 PaddlePaddle 的语音方向的开源模型库,用于语音和音频中的各种关键任务的开发。支持语音识别,语音合成,声纹识别,声音分类,语音唤醒,语音翻译等多种语音任务,荣获 NAACL2022 Best Demo Award 。如果你喜欢这个示例,欢迎在 github 中 star 收藏鼓励。
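For reference, a minimal client-side sketch of the new voice-clone flow exposed by demos/speech_web/speech_server/vc.py (upload a reference recording, list it, then call /vc/clone_g2p). It assumes the server runs locally on the default port 8010 and that the SuccessRequest helper (defined in src/util.py, not shown in this patch) wraps its payload in a "result" field; the local file name is illustrative:

import requests

BASE = "http://localhost:8010"  # default --port of vc.py

# 1) upload a reference recording (multipart form field is named "files")
with open("009901.wav", "rb") as f:  # illustrative local file
    resp = requests.post(f"{BASE}/vc/upload",
                         files=[("files", ("009901.wav", f, "audio/wav"))])
print(resp.json())

# 2) list uploaded recordings; each item is {'name': ..., 'path': ...}
ref = requests.get(f"{BASE}/vc/list").json()["result"][0]

# 3) clone: func='ge2e' uses the GE2E speaker encoder, any other value
#    falls through to the TDNN-based cloner
payload = {
    "wavName": ref["name"],
    "wavPath": ref["path"],
    "text": "测试一下你的合成效果",  # demo sentence from the patch
    "func": "ge2e",
}
out = requests.post(f"{BASE}/vc/clone_g2p", json=payload).json()
print(out)  # on success, the result carries the generated wavName / wavPath
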
diff --git a/demos/speech_web/web_client/src/components/Content/Header/style.less b/demos/speech_web/web_client/src/components/Content/Header/style.less index 9d0261378..cc97c741e 100644 --- a/demos/speech_web/web_client/src/components/Content/Header/style.less +++ b/demos/speech_web/web_client/src/components/Content/Header/style.less @@ -43,6 +43,7 @@ margin-bottom: 40px; display: flex; align-items: center; + margin-top: 40px; }; .speech_header_link { display: block; diff --git a/demos/speech_web/web_client/src/components/Experience.vue b/demos/speech_web/web_client/src/components/Experience.vue index 5620d6af9..4f32faf95 100644 --- a/demos/speech_web/web_client/src/components/Experience.vue +++ b/demos/speech_web/web_client/src/components/Experience.vue @@ -6,6 +6,10 @@ import TTST from './SubMenu/TTS/TTST.vue' import VPRT from './SubMenu/VPR/VPRT.vue' import IET from './SubMenu/IE/IET.vue' +import VoiceCloneT from './SubMenu/VoiceClone/VoiceClone.vue' +import ENIRE_SATT from './SubMenu/ENIRE_SAT/ENIRE_SAT.vue' +import FineTuneT from './SubMenu/FineTune/FineTune.vue' +
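+// Sub-menu components added by this patch: VoiceClone (GE2E/TDNN voice cloning),
+// ENIRE_SAT (ERNIE-SAT synthesis/editing) and FineTune (few-shot finetuning).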