clean old ernie sat inference scripts (#2316)
parent d21e03c03e
commit 7b864e8f38
@@ -1,609 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from typing import Dict
from typing import List

import librosa
import numpy as np
import paddle
import soundfile as sf
from align import alignment
from align import alignment_zh
from align import words2phns
from align import words2phns_zh
from paddle import nn
from sedit_arg_parser import parse_args
from utils import eval_durs
from utils import get_voc_out
from utils import is_chinese
from utils import load_num_sequence_text
from utils import read_2col_text

from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn
from paddlespeech.t2s.models.ernie_sat.mlm import build_model_from_file

random.seed(0)
np.random.seed(0)


def get_wav(wav_path: str,
            source_lang: str='english',
            target_lang: str='english',
            model_name: str="paddle_checkpoint_en",
            old_str: str="",
            new_str: str="",
            non_autoreg: bool=True):
    wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=non_autoreg)

    masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]

    alt_wav = get_voc_out(masked_feat)

    old_time_bdy = [hop_length * x for x in old_span_bdy]

    wav_replaced = np.concatenate(
        [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])

    data_dict = {"origin": wav_org, "output": wav_replaced}

    return data_dict


def load_model(model_name: str="paddle_checkpoint_en"):
    config_path = './pretrained_model/{}/config.yaml'.format(model_name)
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
    mlm_model, conf = build_model_from_file(
        config_file=config_path, model_file=model_path)
    return mlm_model, conf


def read_data(uid: str, prefix: os.PathLike):
    # get the text corresponding to uid
    mfa_text = read_2col_text(prefix + '/text')[uid]
    # get the audio path corresponding to uid
    mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
    if not os.path.isabs(mfa_wav_path):
        mfa_wav_path = prefix + mfa_wav_path
    return mfa_text, mfa_wav_path


def get_align_data(uid: str, prefix: os.PathLike):
    mfa_path = prefix + "mfa_"
    mfa_text = read_2col_text(mfa_path + 'text')[uid]
    mfa_start = load_num_sequence_text(
        mfa_path + 'start', loader_type='text_float')[uid]
    mfa_end = load_num_sequence_text(
        mfa_path + 'end', loader_type='text_float')[uid]
    mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
    return mfa_text, mfa_start, mfa_end, mfa_wav_path


# get the range of mel frames to be masked
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    align_start = np.floor(fs * align_start / hop_length).astype('int')
    align_end = np.floor(fs * align_end / hop_length).astype('int')
    if span_to_repl[0] >= len(mfa_start):
        span_bdy = [align_end[-1], align_end[-1]]
    else:
        span_bdy = [
            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
        ]
    return span_bdy, align_start, align_end
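# A quick worked example for the second-to-frame conversion above (an
# illustrative sketch, not part of the original script): with the defaults
# used throughout these scripts, fs=24000 and hop_length=300, there are 80
# mel frames per second, so a phone starting at 0.55 s lands on frame
# floor(24000 * 0.55 / 300) = 44.
#
#     >>> import numpy as np
#     >>> np.floor(24000 * np.array([0.10, 0.55]) / 300).astype('int')
#     array([ 8, 44])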


def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
    dic = {}
    keys_to_del = []
    exist_idx = []
    sp_count = 0
    add_sp_count = 0
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            exist_idx.append(int(idx))
        else:
            keys_to_del.append(key)

    for key in keys_to_del:
        del word2phns[key]

    cur_id = 0
    for key in tp_word2phns.keys():
        if cur_id in exist_idx:
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
        idx, wrd = key.split('_')
        dic[str(cur_id) + "_" + wrd] = tp_word2phns[key]
        cur_id += 1

    if add_sp_count + 1 == sp_count:
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

    assert add_sp_count == sp_count, "sp tokens are not all added to dic"
    return dic


def get_max_idx(dic):
    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]


def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
    is_append = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []
    # source
    if source_lang == "english":
        intervals, word2phns = alignment(wav_path, old_str)
    elif source_lang == "chinese":
        intervals, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)

        for key, value in tp_word2phns.items():
            idx, wrd = key.split('_')
            cur_val = " ".join(value)
            tp_word2phns[key] = cur_val

        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        assert source_lang == "chinese" or source_lang == "english", \
            "source_lang is wrong..."

    for item in intervals:
        old_phns.append(item[0])
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
    else:
        cross_lingual_clone = False

    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]

        if target_lang == "chinese":
            phns_origin, origin_word2phns = words2phns(str_origin)
            phns_append, append_word2phns_tmp = words2phns_zh(str_append)

        elif target_lang == "english":
            # original sentence
            phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # cloned sentence
            phns_append, append_word2phns_tmp = words2phns(str_append)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not supported for this language, please check it."

        new_phns = phns_origin + phns_append

        append_word2phns = {}
        length = len(origin_word2phns)
        for key, value in append_word2phns_tmp.items():
            idx, wrd = key.split('_')
            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
        new_word2phns = origin_word2phns.copy()
        new_word2phns.update(append_word2phns)

    else:
        if source_lang == target_lang and target_lang == "english":
            new_phns, new_word2phns = words2phns(new_str)
        elif source_lang == target_lang and target_lang == "chinese":
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
            assert source_lang == target_lang, \
                "source language is not the same as the target language..."

    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
    new_phns_left = []
    sp_count = 0
    # find the left different index
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_word2phns:
                left_idx += len(new_word2phns[idx + '_' + wrd])
                new_phns_left.extend(word2phns[key].split())
            else:
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
                break

    # reverse word2phns and new_word2phns
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_idx = get_max_idx(word2phns)
    new_word2phns_max_idx = get_max_idx(new_word2phns)
    new_phns_mid = []
    if is_append:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    # speech edit
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                idx = str(new_word2phns_max_idx -
                          (word2phns_max_idx - int(idx) - sp_count))
                if idx + '_' + wrd in new_word2phns:
                    right_idx -= len(new_word2phns[idx + '_' + wrd])
                    new_phns_right = word2phns[key].split() + new_phns_right
                else:
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    new_phns_mid = new_phns[left_idx:right_idx]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
                    break
    new_phns = new_phns_left + new_phns_mid + new_phns_right
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add: [17, 30] "is impossible to"
    '''
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add


# the durations from MFA and the durations from the fs2 duration_predictor
# may differ, so compute a scaling factor between predicted and real values
def get_dur_adj_factor(orig_dur: List[int],
                       pred_dur: List[int],
                       phns: List[str]):
    length = 0
    factor_list = []
    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
        if pred == 0 or phn == 'sp':
            continue
        else:
            factor_list.append(orig / pred)
    factor_list = np.array(factor_list)
    factor_list.sort()
    if len(factor_list) < 5:
        return 1
    length = 2
    avg = np.average(factor_list[length:-length])
    return avg
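# Worked example for get_dur_adj_factor (illustrative values, not from the
# original file): seven phones whose orig/pred duration ratios sort to
# [0.8, 1.0, 1.0, 1.0, 1.4, 1.5, 1.5]. Dropping the two smallest and two
# largest leaves [1.0, 1.0, 1.4], so the returned factor is their average,
# ~1.13; with fewer than 5 usable ratios the function falls back to 1.
#
#     >>> factor = get_dur_adj_factor(
#     ...     orig_dur=[0.30, 0.20, 0.25, 0.15, 0.40, 0.10, 0.35],
#     ...     pred_dur=[0.20, 0.20, 0.25, 0.10, 0.50, 0.10, 0.25],
#     ...     phns=['f', 'ao1', 'r', 'dh', 'ae1', 't', 'r'])
#     >>> round(float(factor), 2)
#     1.13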


def prep_feats_with_dur(wav_path: str,
                        source_lang: str="English",
                        target_lang: str="English",
                        old_str: str="",
                        new_str: str="",
                        mask_reconstruct: bool=False,
                        duration_adjust: bool=True,
                        start_end_sp: bool=False,
                        fs: int=24000,
                        hop_length: int=300):
    '''
    Returns:
        np.ndarray: new wav, with the part to be edited in the original wav replaced by zeros
        List[str]: new phones
        List[float]: mfa start of new wav
        List[float]: mfa end of new wav
        List[int]: masked mel boundary of original wav
        List[int]: masked mel boundary of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)

    mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang)

    if start_end_sp:
        if new_phns[-1] != 'sp':
            new_phns = new_phns + ['sp']
    # Chinese phones are not necessarily all in the fastspeech2 dictionary,
    # so missing ones are replaced with sp
    if target_lang == "english" or target_lang == "chinese":
        old_durs = eval_durs(old_phns, target_lang=source_lang)
    else:
        assert target_lang == "chinese" or target_lang == "english", \
            "duration prediction is not supported for this language..."

    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
    if '[MASK]' in new_str:
        new_phns = old_phns
        span_to_add = span_to_repl
        d_factor_left = get_dur_adj_factor(
            orig_dur=orig_old_durs[:span_to_repl[0]],
            pred_dur=old_durs[:span_to_repl[0]],
            phns=old_phns[:span_to_repl[0]])
        d_factor_right = get_dur_adj_factor(
            orig_dur=orig_old_durs[span_to_repl[1]:],
            pred_dur=old_durs[span_to_repl[1]:],
            phns=old_phns[span_to_repl[1]:])
        d_factor = (d_factor_left + d_factor_right) / 2
        new_durs_adjusted = [d_factor * i for i in old_durs]
    else:
        if duration_adjust:
            d_factor = get_dur_adj_factor(
                orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
            d_factor = d_factor * 1.25
        else:
            d_factor = 1

        if target_lang == "english" or target_lang == "chinese":
            new_durs = eval_durs(new_phns, target_lang=target_lang)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "duration prediction is not supported for this language..."

        new_durs_adjusted = [d_factor * i for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
    for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + i)
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    # append after the original sentence
    if span_to_repl[0] >= len(mfa_start):
        left_idx = len(wav_org)
        right_idx = left_idx
    # replace in the middle of the original sentence
    else:
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
    # in the original audio, replace the part to be edited with silence whose
    # length is decided by the fs2 duration_predictor
    new_wav = np.concatenate(
        [wav_org[:left_idx], blank_wav, wav_org[right_idx:]])

    # 4. get the old and new mel spans to be masked
    # e.g. [92, 92]

    old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
        mfa_start=mfa_start,
        mfa_end=mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_repl)
    # e.g. [92, 174]
    # new_mfa_start, new_mfa_end: time-level start/end times -> frame level
    new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
        mfa_start=new_mfa_start,
        mfa_end=new_mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_add)

    # old_span_bdy and new_span_bdy are frame-level ranges
    return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy


def prep_feats(wav_path: str,
               source_lang: str="english",
               target_lang: str="english",
               old_str: str="",
               new_str: str="",
               duration_adjust: bool=True,
               start_end_sp: bool=False,
               mask_reconstruct: bool=False,
               fs: int=24000,
               hop_length: int=300,
               token_list: List[str]=[]):
    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
        source_lang=source_lang,
        target_lang=target_lang,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
        fs=fs,
        hop_length=hop_length)

    token_to_id = {item: i for i, item in enumerate(token_list)}
    text = np.array(
        list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
    span_bdy = np.array(new_span_bdy)

    batch = [('1', {
        "speech": wav,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "text": text,
        "span_bdy": span_bdy
    })]

    return batch, old_span_bdy, new_span_bdy
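# Illustrative example of the token lookup in prep_feats above (not from the
# original file): phones missing from token_list fall back to the '<unk>' id.
#
#     >>> token_list = ['<unk>', 'AH0', 'T', 'sp']
#     >>> token_to_id = {t: i for i, t in enumerate(token_list)}
#     >>> [token_to_id.get(p, token_to_id['<unk>']) for p in ['T', 'AH0', 'zz']]
#     [2, 1, 0]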


def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      source_lang: str="english",
                      target_lang: str="english",
                      old_str: str="",
                      new_str: str="",
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      start_end_sp: bool=False,
                      fs: int=24000,
                      hop_length: int=300,
                      token_list: List[str]=[]):
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=fs,
        hop_length=hop_length,
        token_list=token_list)

    feats = collate_fn(batch)[1]

    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the output feature chunks
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length


def get_mlm_output(wav_path: str,
                   model_name: str="paddle_checkpoint_en",
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
    mlm_model, train_conf = load_model(model_name)
    mlm_model.eval()

    collate_fn = build_mlm_collate_fn(
        sr=train_conf.feats_extract_conf['fs'],
        n_fft=train_conf.feats_extract_conf['n_fft'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        win_length=train_conf.feats_extract_conf['win_length'],
        n_mels=train_conf.feats_extract_conf['n_mels'],
        fmin=train_conf.feats_extract_conf['fmin'],
        fmax=train_conf.feats_extract_conf['fmax'],
        mlm_prob=train_conf['mlm_prob'],
        mean_phn_span=train_conf['mean_phn_span'],
        seg_emb=train_conf.encoder_conf['input_layer'] == 'sega_mlm')

    return decode_with_model(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=train_conf.feats_extract_conf['fs'],
        hop_length=train_conf.feats_extract_conf['hop_length'],
        token_list=train_conf.token_list)


def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             prefix: os.PathLike="./prompt/dev/",
             model_name: str="paddle_checkpoint_en",
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):

    # get the original text and the path of the original wav
    old_str, wav_path = read_data(uid=uid, prefix=prefix)

    if task_name == 'edit':
        # use new_str as-is
        pass
    elif task_name == 'synthesize':
        new_str = old_str + new_str
    else:
        new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)])

    print('new_str is ', new_str)

    results_dict = get_wav(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str)
    return results_dict


if __name__ == "__main__":
    # parse config and args
    args = parse_args()

    data_dict = evaluate(
        uid=args.uid,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        prefix=args.prefix,
        model_name=args.model_name,
        new_str=args.new_str,
        task_name=args.task_name)
    sf.write(args.output_name, data_dict['output'], samplerate=24000)
    print("finished...")
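A minimal sketch of driving the script above from Python instead of the CLI
(argument values taken from run_sedit_en.sh below; note that parse_args() is
still consulted inside utils.py, so the vocoder and acoustic-model flags must
also be present on sys.argv):

    import soundfile as sf
    data_dict = evaluate(
        uid='p243_new',
        source_lang='english',
        target_lang='english',
        prefix='./prompt/dev/',
        model_name='paddle_checkpoint_en',
        new_str='for that reason cover is impossible to be given.',
        task_name='edit')
    sf.write('pred_edit.wav', data_dict['output'], samplerate=24000)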
@@ -1,622 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from typing import Dict
from typing import List

import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from align import alignment
from align import alignment_zh
from align import words2phns
from align import words2phns_zh
from paddle import nn
from sedit_arg_parser import parse_args
from utils import eval_durs
from utils import get_voc_out
from utils import is_chinese
from utils import load_num_sequence_text
from utils import read_2col_text
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn
from paddlespeech.t2s.models.ernie_sat.ernie_sat import ErnieSAT

random.seed(0)
np.random.seed(0)


def get_wav(wav_path: str,
            source_lang: str='english',
            target_lang: str='english',
            model_name: str="paddle_checkpoint_en",
            old_str: str="",
            new_str: str="",
            non_autoreg: bool=True):
    wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=non_autoreg)

    masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]

    alt_wav = get_voc_out(masked_feat)

    old_time_bdy = [hop_length * x for x in old_span_bdy]

    wav_replaced = np.concatenate(
        [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])

    data_dict = {"origin": wav_org, "output": wav_replaced}

    return data_dict


def load_model(model_name: str="paddle_checkpoint_en"):
    config_path = './pretrained_model/{}/default.yaml'.format(model_name)
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
    with open(config_path) as f:
        conf = CfgNode(yaml.safe_load(f))
    token_list = list(conf.token_list)
    vocab_size = len(token_list)
    odim = conf.n_mels
    mlm_model = ErnieSAT(idim=vocab_size, odim=odim, **conf["model"])
    state_dict = paddle.load(model_path)
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = "model." + key
        new_state_dict[new_key] = value
    mlm_model.set_state_dict(new_state_dict)
    mlm_model.eval()

    return mlm_model, conf
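# A note on the re-keying loop above (a sketch, not from the original file):
# the checkpoint was saved without the "model." prefix, which suggests the
# network lives under a `model` submodule of ErnieSAT, so each key must be
# renamed before set_state_dict. An equivalent one-liner:
#
#     new_state_dict = {"model." + k: v for k, v in state_dict.items()}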


def read_data(uid: str, prefix: os.PathLike):
    # get the text corresponding to uid
    mfa_text = read_2col_text(prefix + '/text')[uid]
    # get the audio path corresponding to uid
    mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
    if not os.path.isabs(mfa_wav_path):
        mfa_wav_path = prefix + mfa_wav_path
    return mfa_text, mfa_wav_path


def get_align_data(uid: str, prefix: os.PathLike):
    mfa_path = prefix + "mfa_"
    mfa_text = read_2col_text(mfa_path + 'text')[uid]
    mfa_start = load_num_sequence_text(
        mfa_path + 'start', loader_type='text_float')[uid]
    mfa_end = load_num_sequence_text(
        mfa_path + 'end', loader_type='text_float')[uid]
    mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
    return mfa_text, mfa_start, mfa_end, mfa_wav_path


# get the range of mel frames to be masked
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    align_start = np.floor(fs * align_start / hop_length).astype('int')
    align_end = np.floor(fs * align_end / hop_length).astype('int')
    if span_to_repl[0] >= len(mfa_start):
        span_bdy = [align_end[-1], align_end[-1]]
    else:
        span_bdy = [
            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
        ]
    return span_bdy, align_start, align_end


def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
    dic = {}
    keys_to_del = []
    exist_idx = []
    sp_count = 0
    add_sp_count = 0
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            exist_idx.append(int(idx))
        else:
            keys_to_del.append(key)

    for key in keys_to_del:
        del word2phns[key]

    cur_id = 0
    for key in tp_word2phns.keys():
        if cur_id in exist_idx:
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
        idx, wrd = key.split('_')
        dic[str(cur_id) + "_" + wrd] = tp_word2phns[key]
        cur_id += 1

    if add_sp_count + 1 == sp_count:
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

    assert add_sp_count == sp_count, "sp tokens are not all added to dic"
    return dic


def get_max_idx(dic):
    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]


def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
    is_append = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []
    # source
    if source_lang == "english":
        intervals, word2phns = alignment(wav_path, old_str)
    elif source_lang == "chinese":
        intervals, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)

        for key, value in tp_word2phns.items():
            idx, wrd = key.split('_')
            cur_val = " ".join(value)
            tp_word2phns[key] = cur_val

        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        assert source_lang == "chinese" or source_lang == "english", \
            "source_lang is wrong..."

    for item in intervals:
        old_phns.append(item[0])
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
    else:
        cross_lingual_clone = False

    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]

        if target_lang == "chinese":
            phns_origin, origin_word2phns = words2phns(str_origin)
            phns_append, append_word2phns_tmp = words2phns_zh(str_append)

        elif target_lang == "english":
            # original sentence
            phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # cloned sentence
            phns_append, append_word2phns_tmp = words2phns(str_append)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not supported for this language, please check it."

        new_phns = phns_origin + phns_append

        append_word2phns = {}
        length = len(origin_word2phns)
        for key, value in append_word2phns_tmp.items():
            idx, wrd = key.split('_')
            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
        new_word2phns = origin_word2phns.copy()
        new_word2phns.update(append_word2phns)

    else:
        if source_lang == target_lang and target_lang == "english":
            new_phns, new_word2phns = words2phns(new_str)
        elif source_lang == target_lang and target_lang == "chinese":
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
            assert source_lang == target_lang, \
                "source language is not the same as the target language..."

    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
    new_phns_left = []
    sp_count = 0
    # find the left different index
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_word2phns:
                left_idx += len(new_word2phns[idx + '_' + wrd])
                new_phns_left.extend(word2phns[key].split())
            else:
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
                break

    # reverse word2phns and new_word2phns
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_idx = get_max_idx(word2phns)
    new_word2phns_max_idx = get_max_idx(new_word2phns)
    new_phns_mid = []
    if is_append:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    # speech edit
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                idx = str(new_word2phns_max_idx -
                          (word2phns_max_idx - int(idx) - sp_count))
                if idx + '_' + wrd in new_word2phns:
                    right_idx -= len(new_word2phns[idx + '_' + wrd])
                    new_phns_right = word2phns[key].split() + new_phns_right
                else:
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    new_phns_mid = new_phns[left_idx:right_idx]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
                    break
    new_phns = new_phns_left + new_phns_mid + new_phns_right
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add: [17, 30] "is impossible to"
    '''
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add


# the durations from MFA and the durations from the fs2 duration_predictor
# may differ, so compute a scaling factor between predicted and real values
def get_dur_adj_factor(orig_dur: List[int],
                       pred_dur: List[int],
                       phns: List[str]):
    length = 0
    factor_list = []
    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
        if pred == 0 or phn == 'sp':
            continue
        else:
            factor_list.append(orig / pred)
    factor_list = np.array(factor_list)
    factor_list.sort()
    if len(factor_list) < 5:
        return 1
    length = 2
    avg = np.average(factor_list[length:-length])
    return avg


def prep_feats_with_dur(wav_path: str,
                        source_lang: str="English",
                        target_lang: str="English",
                        old_str: str="",
                        new_str: str="",
                        mask_reconstruct: bool=False,
                        duration_adjust: bool=True,
                        start_end_sp: bool=False,
                        fs: int=24000,
                        hop_length: int=300):
    '''
    Returns:
        np.ndarray: new wav, with the part to be edited in the original wav replaced by zeros
        List[str]: new phones
        List[float]: mfa start of new wav
        List[float]: mfa end of new wav
        List[int]: masked mel boundary of original wav
        List[int]: masked mel boundary of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)

    mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang)

    if start_end_sp:
        if new_phns[-1] != 'sp':
            new_phns = new_phns + ['sp']
    # Chinese phones are not necessarily all in the fastspeech2 dictionary,
    # so missing ones are replaced with sp
    if target_lang == "english" or target_lang == "chinese":
        old_durs = eval_durs(old_phns, target_lang=source_lang)
    else:
        assert target_lang == "chinese" or target_lang == "english", \
            "duration prediction is not supported for this language..."

    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
    if '[MASK]' in new_str:
        new_phns = old_phns
        span_to_add = span_to_repl
        d_factor_left = get_dur_adj_factor(
            orig_dur=orig_old_durs[:span_to_repl[0]],
            pred_dur=old_durs[:span_to_repl[0]],
            phns=old_phns[:span_to_repl[0]])
        d_factor_right = get_dur_adj_factor(
            orig_dur=orig_old_durs[span_to_repl[1]:],
            pred_dur=old_durs[span_to_repl[1]:],
            phns=old_phns[span_to_repl[1]:])
        d_factor = (d_factor_left + d_factor_right) / 2
        new_durs_adjusted = [d_factor * i for i in old_durs]
    else:
        if duration_adjust:
            d_factor = get_dur_adj_factor(
                orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
            d_factor = d_factor * 1.25
        else:
            d_factor = 1

        if target_lang == "english" or target_lang == "chinese":
            new_durs = eval_durs(new_phns, target_lang=target_lang)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "duration prediction is not supported for this language..."

        new_durs_adjusted = [d_factor * i for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
    for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + i)
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    # append after the original sentence
    if span_to_repl[0] >= len(mfa_start):
        left_idx = len(wav_org)
        right_idx = left_idx
    # replace in the middle of the original sentence
    else:
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
    # in the original audio, replace the part to be edited with silence whose
    # length is decided by the fs2 duration_predictor
    new_wav = np.concatenate(
        [wav_org[:left_idx], blank_wav, wav_org[right_idx:]])

    # 4. get the old and new mel spans to be masked
    # e.g. [92, 92]

    old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
        mfa_start=mfa_start,
        mfa_end=mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_repl)
    # e.g. [92, 174]
    # new_mfa_start, new_mfa_end: time-level start/end times -> frame level
    new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
        mfa_start=new_mfa_start,
        mfa_end=new_mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_add)

    # old_span_bdy and new_span_bdy are frame-level ranges
    return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy


def prep_feats(wav_path: str,
               source_lang: str="english",
               target_lang: str="english",
               old_str: str="",
               new_str: str="",
               duration_adjust: bool=True,
               start_end_sp: bool=False,
               mask_reconstruct: bool=False,
               fs: int=24000,
               hop_length: int=300,
               token_list: List[str]=[]):
    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
        source_lang=source_lang,
        target_lang=target_lang,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
        fs=fs,
        hop_length=hop_length)

    token_to_id = {item: i for i, item in enumerate(token_list)}
    text = np.array(
        list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
    span_bdy = np.array(new_span_bdy)

    batch = [('1', {
        "speech": wav,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "text": text,
        "span_bdy": span_bdy
    })]

    return batch, old_span_bdy, new_span_bdy


def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      source_lang: str="english",
                      target_lang: str="english",
                      old_str: str="",
                      new_str: str="",
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      start_end_sp: bool=False,
                      fs: int=24000,
                      hop_length: int=300,
                      token_list: List[str]=[]):
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=fs,
        hop_length=hop_length,
        token_list=token_list)

    feats = collate_fn(batch)[1]

    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the output feature chunks
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length


def get_mlm_output(wav_path: str,
                   model_name: str="paddle_checkpoint_en",
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
    mlm_model, train_conf = load_model(model_name)

    collate_fn = build_mlm_collate_fn(
        sr=train_conf.fs,
        n_fft=train_conf.n_fft,
        hop_length=train_conf.n_shift,
        win_length=train_conf.win_length,
        n_mels=train_conf.n_mels,
        fmin=train_conf.fmin,
        fmax=train_conf.fmax,
        mlm_prob=train_conf.mlm_prob,
        mean_phn_span=train_conf.mean_phn_span,
        seg_emb=train_conf.model['enc_input_layer'] == 'sega_mlm')

    return decode_with_model(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=train_conf.fs,
        hop_length=train_conf.n_shift,
        token_list=train_conf.token_list)


def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             prefix: os.PathLike="./prompt/dev/",
             model_name: str="paddle_checkpoint_en",
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):

    # get the original text and the path of the original wav
    old_str, wav_path = read_data(uid=uid, prefix=prefix)

    if task_name == 'edit':
        # use new_str as-is
        pass
    elif task_name == 'synthesize':
        new_str = old_str + new_str
    else:
        new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)])

    print('new_str is ', new_str)

    results_dict = get_wav(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str)
    return results_dict


if __name__ == "__main__":
    # parse config and args
    args = parse_args()

    data_dict = evaluate(
        uid=args.uid,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        prefix=args.prefix,
        model_name=args.model_name,
        new_str=args.new_str,
        task_name=args.task_name)
    sf.write(args.output_name, data_dict['output'], samplerate=24000)
    print("finished...")
@@ -1,97 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse


def parse_args():
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(
        description="Synthesize with acoustic model & vocoder")
    # acoustic model
    parser.add_argument(
        '--am',
        type=str,
        default='fastspeech2_csmsc',
        choices=[
            'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
            'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
            'tacotron2_ljspeech', 'tacotron2_aishell3'
        ],
        help='Choose acoustic model type of tts task.')
    parser.add_argument(
        '--am_config',
        type=str,
        default=None,
        help='Config of acoustic model. Use default config when it is None.')
    parser.add_argument(
        '--am_ckpt',
        type=str,
        default=None,
        help='Checkpoint file of acoustic model.')
    parser.add_argument(
        "--am_stat",
        type=str,
        default=None,
        help="mean and standard deviation used to normalize spectrogram when training acoustic model."
    )
    parser.add_argument(
        "--phones_dict", type=str, default=None, help="phone vocabulary file.")
    parser.add_argument(
        "--tones_dict", type=str, default=None, help="tone vocabulary file.")
    parser.add_argument(
        "--speaker_dict", type=str, default=None, help="speaker id map file.")

    # vocoder
    parser.add_argument(
        '--voc',
        type=str,
        default='pwgan_aishell3',
        choices=[
            'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
            'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc',
            'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk',
            'style_melgan_csmsc'
        ],
        help='Choose vocoder type of tts task.')
    parser.add_argument(
        '--voc_config',
        type=str,
        default=None,
        help='Config of voc. Use default config when it is None.')
    parser.add_argument(
        '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
    parser.add_argument(
        "--voc_stat",
        type=str,
        default=None,
        help="mean and standard deviation used to normalize spectrogram when training voc."
    )
    # other
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")

    parser.add_argument("--model_name", type=str, help="model name")
    parser.add_argument("--uid", type=str, help="uid")
    parser.add_argument("--new_str", type=str, help="new string")
    parser.add_argument("--prefix", type=str, help="prefix")
    parser.add_argument(
        "--source_lang", type=str, default="english", help="source language")
    parser.add_argument(
        "--target_lang", type=str, default="english", help="target language")
    parser.add_argument("--output_name", type=str, help="output name")
    parser.add_argument("--task_name", type=str, help="task name")

    # pre
    args = parser.parse_args()
    return args
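One wrinkle worth a sketch: utils.py below calls parse_args() again inside
get_voc_out and eval_durs, so the vocoder and acoustic-model flags must be
present on sys.argv even when calling the Python API directly. A hypothetical
workaround (not from the original code; flag values taken from the run
scripts below):

    import sys
    sys.argv += [
        '--voc', 'pwgan_aishell3',
        '--voc_config', 'download/pwg_aishell3_ckpt_0.5/default.yaml',
        '--voc_ckpt', 'download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz',
        '--voc_stat', 'download/pwg_aishell3_ckpt_0.5/feats_stats.npy',
    ]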
@@ -1,175 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union

import numpy as np
import paddle
import yaml
from sedit_arg_parser import parse_args
from yacs.config import CfgNode

from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference


def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
    """Read a text file having 2 columns as a dict object.

    Examples:
        wav.scp:
            key1 /some/path/a.wav
            key2 /some/path/b.wav

        >>> read_2col_text('wav.scp')
        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}

    """
    data = {}
    with Path(path).open("r", encoding="utf-8") as f:
        for linenum, line in enumerate(f, 1):
            sps = line.rstrip().split(maxsplit=1)
            if len(sps) == 1:
                k, v = sps[0], ""
            else:
                k, v = sps
            if k in data:
                raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
            data[k] = v
    return data


def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
                           ) -> Dict[str, List[Union[float, int]]]:
    """Read a text file indicating sequences of numbers.

    Examples:
        key1 1 2 3
        key2 34 5 6

        >>> d = load_num_sequence_text('text')
        >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
    """
    if loader_type == "text_int":
        delimiter = " "
        dtype = int
    elif loader_type == "text_float":
        delimiter = " "
        dtype = float
    elif loader_type == "csv_int":
        delimiter = ","
        dtype = int
    elif loader_type == "csv_float":
        delimiter = ","
        dtype = float
    else:
        raise ValueError(f"Not supported loader_type={loader_type}")

    # path looks like:
    #   utta 1,0
    #   uttb 3,4,5
    # -> return {'utta': np.ndarray([1, 0]),
    #            'uttb': np.ndarray([3, 4, 5])}
    d = read_2col_text(path)
    # Using a for-loop instead of a dict comprehension for debuggability
    retval = {}
    for k, v in d.items():
        try:
            retval[k] = [dtype(i) for i in v.split(delimiter)]
        except TypeError:
            print(f'Error happened with path="{path}", id="{k}", value="{v}"')
            raise
    return retval


def is_chinese(ch):
    # True if ch is a CJK unified ideograph (U+4E00..U+9FFF)
    if u'\u4e00' <= ch <= u'\u9fff':
        return True
    else:
        return False


def get_voc_out(mel):
    # vocoder
    args = parse_args()
    with open(args.voc_config) as f:
        voc_config = CfgNode(yaml.safe_load(f))
    voc_inference = get_voc_inference(
        voc=args.voc,
        voc_config=voc_config,
        voc_ckpt=args.voc_ckpt,
        voc_stat=args.voc_stat)

    with paddle.no_grad():
        wav = voc_inference(mel)
    return np.squeeze(wav)


def eval_durs(phns, target_lang="chinese", fs=24000, hop_length=300):
    args = parse_args()

    if target_lang == 'english':
        args.am = "fastspeech2_ljspeech"
        args.am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
        args.am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
        args.am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"

    elif target_lang == 'chinese':
        args.am = "fastspeech2_csmsc"
        args.am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
        args.am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
        args.am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
        args.phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"

    if args.ngpu == 0:
        paddle.set_device("cpu")
    elif args.ngpu > 0:
        paddle.set_device("gpu")
    else:
        print("ngpu should be >= 0 !")

    # Init body.
    with open(args.am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))

    am_inference, am = get_am_inference(
        am=args.am,
        am_config=am_config,
        am_ckpt=args.am_ckpt,
        am_stat=args.am_stat,
        phones_dict=args.phones_dict,
        tones_dict=args.tones_dict,
        speaker_dict=args.speaker_dict,
        return_am=True)

    vocab_phones = {}
    with open(args.phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    for phn, idx in phn_id:
        vocab_phones[phn] = int(idx)
    vocab_size = len(vocab_phones)
    # phones missing from the vocabulary are replaced with 'sp'
    phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]

    phone_ids = [vocab_phones[item] for item in phonemes]
    phone_ids.append(vocab_size - 1)
    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
    _, d_outs, _, _ = am.inference(phone_ids, spk_id=None, spk_emb=None)
    pre_d_outs = d_outs
    # convert predicted frame counts to seconds and drop the trailing token
    phu_durs_new = pre_d_outs * hop_length / fs
    phu_durs_new = phu_durs_new.tolist()[:-1]
    return phu_durs_new
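eval_durs above converts the duration predictor's per-phone frame counts into
seconds via dur * hop_length / fs. A quick check of that arithmetic with the
default fs=24000 and hop_length=300 (illustrative frame counts, not from the
original code):

    frames = [12, 4, 20]                      # predicted frames per phone
    secs = [n * 300 / 24000 for n in frames]  # -> [0.15, 0.05, 0.25] seconds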
@@ -1,13 +0,0 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

MODEL=ernie_sat
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
@@ -1,3 +0,0 @@
p243_new For that reason cover should not be given.
Prompt_003_new This was not the show for me.
p299_096 We are trying to establish a date.
@@ -1,3 +0,0 @@
p243_new ../../prompt_wav/p243_313.wav
Prompt_003_new ../../prompt_wav/this_was_not_the_show_for_me.wav
p299_096 ../../prompt_wav/p299_096.wav
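These two prompt files are the 2-column tables that read_data consumes:
prompt/dev/text maps each uid to its transcript and prompt/dev/wav.scp maps
it to a wav path. A minimal sketch of loading them with the read_2col_text
helper from utils.py above:

    text = read_2col_text('./prompt/dev/text')
    scp = read_2col_text('./prompt/dev/wav.scp')
    print(text['p243_new'])  # For that reason cover should not be given.
    print(scp['p243_new'])   # ../../prompt_wav/p243_313.wav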
@@ -1,27 +0,0 @@
#!/bin/bash

set -e
source path.sh

# en --> zh speech synthesis
# Use Prompt_003_new as the prompt speech: "This was not the show for me."
# to synthesize: '今天天气很好'
# Note: the input new_str must consist of Chinese characters; otherwise
# preprocessing keeps only the Chinese characters, i.e. the synthesized
# speech comes from the preprocessed Chinese text.

python local/inference.py \
    --task_name=cross-lingual_clone \
    --model_name=paddle_checkpoint_dual_mask_enzh \
    --uid=Prompt_003_new \
    --new_str='今天天气很好.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=chinese \
    --output_name=pred_clone.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_csmsc \
    --am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
    --am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
    --am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
@@ -1,27 +0,0 @@
#!/bin/bash

set -e
source path.sh

# en --> zh speech synthesis
# Use Prompt_003_new as the prompt speech: "This was not the show for me."
# to synthesize: '今天天气很好'
# Note: the input new_str must consist of Chinese characters; otherwise
# preprocessing keeps only the Chinese characters, i.e. the synthesized
# speech comes from the preprocessed Chinese text.

python local/inference_new.py \
    --task_name=cross-lingual_clone \
    --model_name=paddle_checkpoint_dual_mask_enzh \
    --uid=Prompt_003_new \
    --new_str='今天天气很好.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=chinese \
    --output_name=pred_clone.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_csmsc \
    --am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
    --am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
    --am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
@@ -1,26 +0,0 @@
#!/bin/bash

set -e
source path.sh

# English-only speech synthesis
# Example: use the speech corresponding to p299_096 as the prompt speech:
# "This was not the show for me." to synthesize: 'I enjoy my life.'

python local/inference.py \
    --task_name=synthesize \
    --model_name=paddle_checkpoint_en \
    --uid=p299_096 \
    --new_str='I enjoy my life, do you?' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_gen.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_ljspeech \
    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
@ -1,26 +0,0 @@
#!/bin/bash

set -e
source path.sh

# English-only speech synthesis
# Example: use the speech for p299_096 ("We are trying to establish a date.") as the prompt to synthesize: 'I enjoy my life, do you?'

python local/inference_new.py \
    --task_name=synthesize \
    --model_name=paddle_checkpoint_en \
    --uid=p299_096 \
    --new_str='I enjoy my life, do you?' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_gen.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_ljspeech \
    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
@ -1,27 +0,0 @@
#!/bin/bash

set -e
source path.sh

# English-only speech editing
# Example: edit the original speech for p243_new ("For that reason cover should not be given.") into the speech for 'for that reason cover is impossible to be given.'
# NOTE: the speech editing task currently supports replacing or inserting text at only one position in a sentence

python local/inference.py \
    --task_name=edit \
    --model_name=paddle_checkpoint_en \
    --uid=p243_new \
    --new_str='for that reason cover is impossible to be given.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_edit.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_ljspeech \
    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
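
Because only one contiguous region may change, the edited region can be located by scanning the old and new word sequences from both ends. A minimal sketch of such a span finder (a hypothetical helper; the example's actual span detection uses its MFA-based alignment utilities):

def find_edit_span(old_words, new_words):
    # match a common prefix, then a common suffix; the middle is the edit
    n = min(len(old_words), len(new_words))
    i = 0
    while i < n and old_words[i] == new_words[i]:
        i += 1
    j = 0
    while j < n - i and old_words[-1 - j] == new_words[-1 - j]:
        j += 1
    return (i, len(old_words) - j), (i, len(new_words) - j)

old = "for that reason cover should not be given.".split()
new = "for that reason cover is impossible to be given.".split()
print(find_edit_span(old, new))  # ((4, 6), (4, 7)): 'should not' -> 'is impossible to'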
@ -1,27 +0,0 @@
#!/bin/bash

set -e
source path.sh

# English-only speech editing
# Example: edit the original speech for p243_new ("For that reason cover should not be given.") into the speech for 'for that reason cover is impossible to be given.'
# NOTE: the speech editing task currently supports replacing or inserting text at only one position in a sentence

python local/inference_new.py \
    --task_name=edit \
    --model_name=paddle_checkpoint_en \
    --uid=p243_new \
    --new_str='for that reason cover is impossible to be given.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_edit.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_ljspeech \
    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
@ -1,6 +0,0 @@
#!/bin/bash

rm -rf *.wav
./run_sedit_en.sh        # speech editing task (English)
./run_gen_en.sh          # personalized speech synthesis task (English)
./run_clone_en_to_zh.sh  # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
@ -1,6 +0,0 @@
#!/bin/bash

rm -rf *.wav
./run_sedit_en_new.sh        # speech editing task (English)
./run_gen_en_new.sh          # personalized speech synthesis task (English)
./run_clone_en_to_zh_new.sh  # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
@ -1,579 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Dict
from typing import List
from typing import Optional

import paddle
import yaml
from paddle import nn
from yacs.config import CfgNode

from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
from paddlespeech.t2s.modules.layer_norm import LayerNorm
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear
from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
from paddlespeech.t2s.modules.transformer.repeat import repeat
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling


# MLM -> Masked Language Model
class mySequential(nn.Sequential):
    # like nn.Sequential, but unpacks tuple outputs into the next layer's
    # positional arguments, so stages can pass (x, pos_emb) pairs along
    def forward(self, *inputs):
        for module in self._sub_layers.values():
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs
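
A minimal usage sketch for mySequential (illustrative only; AddPos is a hypothetical stage), showing why tuple outputs need unpacking:

import paddle
from paddle import nn

class AddPos(nn.Layer):
    # hypothetical stage: returns a (features, positional-embedding) pair
    def forward(self, x, pos):
        return x + pos, pos

seq = mySequential(AddPos(), AddPos())
x, pos = seq(paddle.zeros([2, 3]), paddle.ones([2, 3]))
print(x)  # every element is 2.0: pos was applied once per stage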


class MaskInputLayer(nn.Layer):
    def __init__(self, out_features: int) -> None:
        super().__init__()
        # learnable mask embedding, substituted for masked speech frames
        self.mask_feature = paddle.create_parameter(
            shape=(1, 1, out_features),
            dtype=paddle.float32,
            default_initializer=paddle.nn.initializer.Assign(
                paddle.normal(shape=(1, 1, out_features))))

    def forward(self, input: paddle.Tensor,
                masked_pos: paddle.Tensor=None) -> paddle.Tensor:
        masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
        # zero out masked frames, then add the mask embedding at those frames
        masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
            paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
        return masked_input
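
A quick sketch of the effect (assuming the class above; shapes are illustrative):

import paddle

layer = MaskInputLayer(out_features=4)
feats = paddle.randn([1, 3, 4])                       # (batch, frames, dim)
masked_pos = paddle.to_tensor([[True, False, True]])  # frames 0 and 2 masked
out = layer(feats, masked_pos)
# out[:, 1] equals feats[:, 1]; out[:, 0] and out[:, 2] equal the mask embedding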


class MLMEncoder(nn.Layer):
    """Conformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, paddle.nn.Layer]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.

    """

    def __init__(self,
                 idim: int,
                 vocab_size: int=0,
                 pre_speech_layer: int=0,
                 attention_dim: int=256,
                 attention_heads: int=4,
                 linear_units: int=2048,
                 num_blocks: int=6,
                 dropout_rate: float=0.1,
                 positional_dropout_rate: float=0.1,
                 attention_dropout_rate: float=0.0,
                 input_layer: str="conv2d",
                 normalize_before: bool=True,
                 concat_after: bool=False,
                 positionwise_layer_type: str="linear",
                 positionwise_conv_kernel_size: int=1,
                 macaron_style: bool=False,
                 pos_enc_layer_type: str="abs_pos",
                 selfattention_layer_type: str="selfattn",
                 activation_type: str="swish",
                 use_cnn_module: bool=False,
                 zero_triu: bool=False,
                 cnn_module_kernel: int=31,
                 padding_idx: int=-1,
                 stochastic_depth_rate: float=0.0,
                 text_masking: bool=False):
        """Construct an Encoder object."""
        super().__init__()
        self._output_size = attention_dim
        self.text_masking = text_masking
        if self.text_masking:
            self.text_masking_layer = MaskInputLayer(attention_dim)
        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            pos_enc_class = LegacyRelPositionalEncoding
            assert selfattention_layer_type == "legacy_rel_selfattn"
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        self.conv_subsampling_factor = 1
        if input_layer == "linear":
            self.embed = nn.Sequential(
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.Dropout(dropout_rate),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate), )
            self.conv_subsampling_factor = 4
        elif input_layer == "embed":
            self.embed = nn.Sequential(
                nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "mlm":
            self.segment_emb = None
            self.speech_embed = mySequential(
                MaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate))
            self.text_embed = nn.Sequential(
                nn.Embedding(
                    vocab_size, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer == "sega_mlm":
            self.segment_emb = nn.Embedding(
                500, attention_dim, padding_idx=padding_idx)
            self.speech_embed = mySequential(
                MaskInputLayer(idim),
                nn.Linear(idim, attention_dim),
                nn.LayerNorm(attention_dim),
                nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate))
            self.text_embed = nn.Sequential(
                nn.Embedding(
                    vocab_size, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif isinstance(input_layer, nn.Layer):
            self.embed = nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate), )
        elif input_layer is None:
            self.embed = nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate))
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        # self-attention module definition
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, attention_dim,
                                           attention_dropout_rate, zero_triu, )
        else:
            raise ValueError("unknown encoder_attn_layer: " +
                             selfattention_layer_type)

        # feed-forward module definition
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (attention_dim, linear_units,
                                       dropout_rate, activation, )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (attention_dim, linear_units,
                                       positionwise_conv_kernel_size,
                                       dropout_rate, )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (attention_dim, linear_units,
                                       positionwise_conv_kernel_size,
                                       dropout_rate, )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        # stochastic depth: the skip probability grows linearly with block depth
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / num_blocks, ), )
        self.pre_speech_layer = pre_speech_layer
        self.pre_speech_encoders = repeat(
            self.pre_speech_layer,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor=None,
                text_mask: paddle.Tensor=None,
                speech_seg_pos: paddle.Tensor=None,
                text_seg_pos: paddle.Tensor=None):
        """Encode masked speech and text input sequences."""
        if masked_pos is not None:
            speech = self.speech_embed(speech, masked_pos)
        else:
            speech = self.speech_embed(speech)
        if text is not None:
            text = self.text_embed(text)
        if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
            speech_seg_emb = self.segment_emb(speech_seg_pos)
            text_seg_emb = self.segment_emb(text_seg_pos)
            text = (text[0] + text_seg_emb, text[1])
            speech = (speech[0] + speech_seg_emb, speech[1])
        if self.pre_speech_encoders:
            speech, _ = self.pre_speech_encoders(speech, speech_mask)

        if text is not None:
            # concatenate speech and text along the time axis into one sequence
            xs = paddle.concat([speech[0], text[0]], axis=1)
            xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1)
            masks = paddle.concat([speech_mask, text_mask], axis=-1)
        else:
            xs = speech[0]
            xs_pos_emb = speech[1]
            masks = speech_mask

        xs, masks = self.encoders((xs, xs_pos_emb), masks)

        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs, masks
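
A minimal construction sketch (illustrative hyperparameters; real values come from the checkpoint's config.yaml):

# small encoder in "mlm" mode: speech frames of dim 80, a 100-token vocabulary
enc = MLMEncoder(
    idim=80,
    vocab_size=100,
    input_layer="mlm",
    attention_dim=64,
    attention_heads=2,
    linear_units=128,
    num_blocks=2)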


class MLMDecoder(MLMEncoder):
    def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
        """Encode input sequence.

        Args:
            xs (paddle.Tensor): Input tensor (#batch, time, idim).
            masks (paddle.Tensor): Mask tensor (#batch, time).

        Returns:
            paddle.Tensor: Output tensor (#batch, time, attention_dim).
            paddle.Tensor: Mask tensor (#batch, time).

        """
        xs = self.embed(xs)
        xs, masks = self.encoders(xs, masks)

        if isinstance(xs, tuple):
            xs = xs[0]
        if self.normalize_before:
            xs = self.after_norm(xs)

        return xs, masks


# encoder and decoder are nn.Layer instances here, not strings
class MLM(nn.Layer):
    def __init__(self,
                 odim: int,
                 encoder: nn.Layer,
                 decoder: Optional[nn.Layer],
                 postnet_layers: int=0,
                 postnet_chans: int=0,
                 postnet_filts: int=0,
                 text_masking: bool=False):

        super().__init__()
        self.odim = odim
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = encoder.text_embed[0]._num_embeddings

        if self.decoder is None or not (hasattr(self.decoder,
                                                'output_layer') and
                                        self.decoder.output_layer is not None):
            self.sfc = nn.Linear(self.encoder._output_size, odim)
        else:
            self.sfc = None
        if text_masking:
            self.text_sfc = nn.Linear(
                self.encoder.text_embed[0]._embedding_dim,
                self.vocab_size,
                weight_attr=self.encoder.text_embed[0]._weight_attr)
        else:
            self.text_sfc = None

        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=self.encoder._output_size,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=True,
            dropout_rate=0.5, ))

    def inference(
            self,
            speech: paddle.Tensor,
            text: paddle.Tensor,
            masked_pos: paddle.Tensor,
            speech_mask: paddle.Tensor,
            text_mask: paddle.Tensor,
            speech_seg_pos: paddle.Tensor,
            text_seg_pos: paddle.Tensor,
            span_bdy: List[int],
            use_teacher_forcing: bool=False, ) -> Optional[List[paddle.Tensor]]:
        '''
        Args:
            speech (paddle.Tensor): input speech (1, Tmax, D).
            text (paddle.Tensor): input text (1, Tmax2).
            masked_pos (paddle.Tensor): masked position of input speech (1, Tmax).
            speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax).
            text_mask (paddle.Tensor): mask of text (1, 1, Tmax2).
            speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (1, Tmax).
            text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (1, Tmax2).
            span_bdy (List[int]): masked mel boundary of input speech (2,).
            use_teacher_forcing (bool): whether to use teacher forcing.
        Returns:
            List[Tensor]:
                e.g.
                [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
        '''
        if use_teacher_forcing:
            before_outs, zs, *_ = self.forward(
                speech=speech,
                text=text,
                masked_pos=masked_pos,
                speech_mask=speech_mask,
                text_mask=text_mask,
                speech_seg_pos=speech_seg_pos,
                text_seg_pos=text_seg_pos)
            if zs is None:
                zs = before_outs

            speech = speech.squeeze(0)
            # splice: original frames before the span, regenerated frames
            # inside the span, original frames after the span
            outs = [speech[:span_bdy[0]]]
            outs += [zs[0][span_bdy[0]:span_bdy[1]]]
            outs += [speech[span_bdy[1]:]]
            return outs
        return None
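
A sketch of how a caller might stitch the three returned pieces back into one feature sequence (hypothetical usage with dummy tensors shaped like the docstring example, after squeezing the batch axis):

import paddle

outs = [paddle.zeros([181, 80]), paddle.zeros([80, 80]), paddle.zeros([67, 80])]
full_feats = paddle.concat(outs, axis=0)  # (181 + 80 + 67, 80)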


class MLMEncAsDecoder(MLM):
    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor,
                text_mask: paddle.Tensor,
                speech_seg_pos: paddle.Tensor,
                text_seg_pos: paddle.Tensor):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, h_masks = self.encoder(
            speech=speech,
            text=text,
            masked_pos=masked_pos,
            speech_mask=speech_mask,
            text_mask=text_mask,
            speech_seg_pos=speech_seg_pos,
            text_seg_pos=text_seg_pos)
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
        else:
            before_outs = speech_hidden_states
        if self.postnet is not None:
            after_outs = before_outs + paddle.transpose(
                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
                [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, None


class MLMDualMasking(MLM):
    def forward(self,
                speech: paddle.Tensor,
                text: paddle.Tensor,
                masked_pos: paddle.Tensor,
                speech_mask: paddle.Tensor,
                text_mask: paddle.Tensor,
                speech_seg_pos: paddle.Tensor,
                text_seg_pos: paddle.Tensor):
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, h_masks = self.encoder(
            speech=speech,
            text=text,
            masked_pos=masked_pos,
            speech_mask=speech_mask,
            text_mask=text_mask,
            speech_seg_pos=speech_seg_pos,
            text_seg_pos=text_seg_pos)
        if self.decoder is not None:
            zs, _ = self.decoder(encoder_out, h_masks)
        else:
            zs = encoder_out
        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        # avoid an unbound local when text masking is disabled
        text_outs = None
        if self.text_sfc:
            text_hidden_states = zs[:, paddle.shape(speech)[1]:, :]
            text_outs = paddle.reshape(
                self.text_sfc(text_hidden_states),
                (paddle.shape(text_hidden_states)[0], -1, self.vocab_size))
        if self.sfc is not None:
            before_outs = paddle.reshape(
                self.sfc(speech_hidden_states),
                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
        else:
            before_outs = speech_hidden_states
        if self.postnet is not None:
            after_outs = before_outs + paddle.transpose(
                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
                [0, 2, 1])
        else:
            after_outs = None
        return before_outs, after_outs, text_outs


def build_model_from_file(config_file, model_file):

    state_dict = paddle.load(model_file)
    model_class = MLMDualMasking if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
        else MLMEncAsDecoder

    # build the model from its config, then load the checkpoint weights
    with open(config_file) as f:
        conf = CfgNode(yaml.safe_load(f))
    model = build_model(conf, model_class)
    model.set_state_dict(state_dict)
    return model, conf


# select encoder and decoder here
def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]

        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")

    vocab_size = len(token_list)
    odim = 80

    # Encoder
    encoder_class = MLMEncoder

    if 'text_masking' in args.model_conf.keys() and args.model_conf[
            'text_masking']:
        args.encoder_conf['text_masking'] = True
    else:
        args.encoder_conf['text_masking'] = False

    encoder = encoder_class(
        args.input_size, vocab_size=vocab_size, **args.encoder_conf)

    # Decoder
    if args.decoder != 'no_decoder':
        decoder_class = MLMDecoder
        decoder = decoder_class(
            idim=0,
            input_layer=None,
            **args.decoder_conf, )
    else:
        decoder = None

    # Build model
    model = model_class(
        odim=odim,
        encoder=encoder,
        decoder=decoder,
        **args.model_conf, )

    # Initialize
    if args.init is not None:
        initialize(model, args.init)

    return model
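
For reference, a minimal sketch of the namespace build_model consumes (field names taken from the attribute accesses above; the values are illustrative assumptions, not from a real checkpoint config):

import argparse

# illustrative values only; real configs come from the checkpoint's config.yaml
args = argparse.Namespace(
    token_list=["<blank>", "<unk>", "AH0", "N", "<sos/eos>"],
    input_size=80,
    init=None,
    decoder="no_decoder",
    encoder_conf={"input_layer": "sega_mlm", "num_blocks": 2},
    model_conf={"text_masking": False},
)
model = build_model(args, model_class=MLMEncAsDecoder)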