You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
602 lines
21 KiB
602 lines
21 KiB
2 years ago
|
#!/usr/bin/env python3
|
||
|
import os
|
||
|
import random
|
||
|
from typing import Dict
|
||
|
from typing import List
|
||
|
|
||
|
import librosa
|
||
|
import numpy as np
|
||
|
import paddle
|
||
|
import soundfile as sf
|
||
|
from align import alignment
|
||
|
from align import alignment_zh
|
||
|
from align import words2phns
|
||
|
from align import words2phns_zh
|
||
|
from paddle import nn
|
||
|
from sedit_arg_parser import parse_args
|
||
|
from utils import eval_durs
|
||
|
from utils import get_voc_out
|
||
|
from utils import is_chinese
|
||
|
from utils import load_num_sequence_text
|
||
|
from utils import read_2col_text
|
||
|
|
||
|
from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn
|
||
|
from paddlespeech.t2s.models.ernie_sat.mlm import build_model_from_file
|
||
|
|
||
|
# Seed both random number generators used by this module (Python's `random`
# first, then NumPy's global RNG) so that repeated runs are reproducible.
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(0)
del _seed_fn
|
||
|
|
||
|
|
||
|
def get_wav(wav_path: str,
            source_lang: str='english',
            target_lang: str='english',
            model_name: str="paddle_checkpoint_en",
            old_str: str="",
            new_str: str="",
            non_autoreg: bool=True):
    """Edit the recording at ``wav_path`` and return original and edited audio.

    The model output frames inside the new span are turned into a waveform
    with ``get_voc_out``. That waveform replaces the samples of the original
    recording that lie inside the old span.

    Args:
        wav_path: path of the original recording.
        source_lang: language of ``old_str``.
        target_lang: language of ``new_str``.
        model_name: directory name under ``./pretrained_model``.
        old_str: transcript of the original recording.
        new_str: transcript after editing.
        non_autoreg: forwarded to ``get_mlm_output`` as ``use_teacher_forcing``.

    Returns:
        dict with ``"origin"`` (original waveform) and ``"output"`` (edited
        waveform).
    """
    mlm_result = get_mlm_output(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=non_autoreg)
    wav_org, output_feat, old_span_bdy, new_span_bdy, _fs, hop_length = mlm_result

    # Only the frames of the new span are converted back to audio.
    span_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]
    alt_wav = get_voc_out(span_feat)

    # Frame-level boundaries of the old span -> sample-level boundaries.
    cut_points = [hop_length * frame for frame in old_span_bdy]
    pieces = [wav_org[:cut_points[0]], alt_wav, wav_org[cut_points[1]:]]
    return {"origin": wav_org, "output": np.concatenate(pieces)}
|
||
|
|
||
|
|
||
|
def load_model(model_name: str="paddle_checkpoint_en"):
    """Build the model stored in ``./pretrained_model/<model_name>/``.

    The directory must contain ``config.yaml`` and ``model.pdparams``.

    Returns:
        (mlm_model, conf) as produced by ``build_model_from_file``.
    """
    model_dir = f'./pretrained_model/{model_name}'
    model, config = build_model_from_file(
        config_file=model_dir + '/config.yaml',
        model_file=model_dir + '/model.pdparams')
    return model, config
|
||
|
|
||
|
|
||
|
def read_data(uid: str, prefix: os.PathLike):
    """Look up the transcript and wav path of utterance ``uid``.

    ``prefix`` is a directory holding the two-column files ``text`` and
    ``wav.scp`` (read with ``read_2col_text``). A relative wav path found in
    ``wav.scp`` is resolved against ``prefix``.

    Paths are built with ``os.path.join``. A prefix with or without a
    trailing separator therefore works, and so does a ``pathlib.Path``.

    Args:
        uid: utterance id (first column of both files).
        prefix: directory containing ``text`` and ``wav.scp``.

    Returns:
        (text, wav_path) for ``uid``.

    Raises:
        KeyError: if ``uid`` is missing from either file.
    """
    prefix = os.fspath(prefix)
    # transcript of uid
    mfa_text = read_2col_text(os.path.join(prefix, 'text'))[uid]
    # audio path of uid
    mfa_wav_path = read_2col_text(os.path.join(prefix, 'wav.scp'))[uid]
    if not os.path.isabs(mfa_wav_path):
        # Plain "prefix + path" would glue the names together when the
        # prefix lacks a trailing separator.
        mfa_wav_path = os.path.join(prefix, mfa_wav_path)
    return mfa_text, mfa_wav_path
|
||
|
|
||
|
|
||
|
def get_align_data(uid: str, prefix: os.PathLike):
    """Read the stored alignment of utterance ``uid``.

    Files are named ``<prefix>mfa_text``, ``<prefix>mfa_start``,
    ``<prefix>mfa_end`` and ``<prefix>mfa_wav.scp``. ``prefix`` is joined by
    plain string concatenation, so it must carry its own trailing separator
    (or be a file-name prefix).

    Returns:
        (text, start sequence, end sequence, wav path) for ``uid``; the start
        and end sequences are loaded with ``loader_type='text_float'``.
    """
    base = prefix + "mfa_"

    def _float_column(suffix):
        table = load_num_sequence_text(base + suffix, loader_type='text_float')
        return table[uid]

    text = read_2col_text(base + 'text')[uid]
    starts = _float_column('start')
    ends = _float_column('end')
    wav_path = read_2col_text(base + 'wav.scp')[uid]
    return text, starts, ends, wav_path
|
||
|
|
||
|
|
||
|
# Get the range of mel frames that need to be masked.
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[int]):
    """Convert per-phone times to frame indices and locate a span's frames.

    Args:
        mfa_start: per-phone start times in seconds.
        mfa_end: per-phone end times in seconds.
        fs: sample rate.
        hop_length: number of samples between consecutive frames.
        span_to_repl: ``[first, end)`` phone indices of the span (end is
            exclusive).

    Returns:
        (span_bdy, align_start, align_end).
        ``align_start`` / ``align_end`` are integer numpy arrays of per-phone
        frame indices, computed as ``floor(fs * t / hop_length)``.
        ``span_bdy`` is ``[start_frame, end_frame]`` of the span. When the span
        starts at or after the last phone (append), it is the empty range at
        the end of the last phone. With no phones at all it is ``[0, 0]``.
    """
    align_start = np.floor(fs * np.array(mfa_start) / hop_length).astype('int')
    align_end = np.floor(fs * np.array(mfa_end) / hop_length).astype('int')
    if span_to_repl[0] >= len(mfa_start):
        # Appending after the last phone; guard the empty alignment, where
        # align_end[-1] would raise IndexError.
        last = align_end[-1] if len(align_end) else 0
        span_bdy = [last, last]
    else:
        span_bdy = [
            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
        ]
    return span_bdy, align_start, align_end
|
||
|
|
||
|
|
||
|
def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
    """Merge the aligner's ``sp`` entries into the front-end's word->phones map.

    Keys of both dicts look like ``"<index>_<word>"``.

    - ``word2phns`` (from the aligner) is consulted only for the indices of its
      ``"<i>_sp"`` entries. ``sp`` is presumably the silence/short-pause token.
    - ``tp_word2phns`` (from the text front-end) supplies the words and phones.

    The words of ``tp_word2phns`` are renumbered in order. An
    ``"<i>_sp": "sp"`` entry is inserted wherever the running index ``i``
    was an ``sp`` position in ``word2phns``. One further ``sp`` may follow
    the last word.

    Neither argument is modified.

    Returns:
        The merged dict.

    Raises:
        AssertionError: if not every ``sp`` of ``word2phns`` could be placed.
        ValueError: if a key does not contain exactly one ``'_'``.
    """
    # Positions of the aligner's sp entries; a set gives O(1) lookups below.
    sp_positions = set()
    sp_count = 0
    for key in word2phns:
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            sp_positions.add(int(idx))

    dic = {}
    cur_id = 0
    add_sp_count = 0
    for key, phns in tp_word2phns.items():
        # An sp occupied this index in the aligner's numbering: emit it first
        # and shift the current word by one.
        if cur_id in sp_positions:
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
        _, wrd = key.split('_')
        dic[str(cur_id) + "_" + wrd] = phns
        cur_id += 1

    # Exactly one sp still missing: it comes after the last word.
    if add_sp_count + 1 == sp_count:
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

    assert add_sp_count == sp_count, "sp are not added in dic"
    return dic
|
||
|
|
||
|
|
||
|
def get_max_idx(dic):
    """Return the largest numeric index among keys shaped ``"<index>_<word>"``.

    Indices are compared as integers, so ``"10_x"`` beats ``"9_y"``.

    Raises:
        ValueError: if ``dic`` is empty.
    """
    if not dic:
        raise ValueError("get_max_idx() requires a non-empty dict")
    # A single O(n) pass; no need to sort just to take the largest value.
    return max(int(key.split('_')[0]) for key in dic)
|
||
|
|
||
|
|
||
|
def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
    """Align the original utterance and locate the phone spans that differ.

    Steps:
    1. Force-align the recording ``wav_path`` (transcript ``old_str``) to get
       its phones and per-phone times.
    2. Convert ``new_str`` to phones with the text front-end.
    3. Match words from the left and from the right. This finds the span of
       ``old_phns`` to replace and the span of ``new_phns`` to generate.

    If ``new_str`` starts with ``old_str``, the task is treated as appending.
    If the languages also differ, it is a cross-lingual clone: the appended
    part is converted with the target language's front-end.

    Args:
        wav_path: path of the original recording.
        old_str: transcript of ``wav_path``.
        new_str: desired transcript.
        source_lang: ``"english"`` or ``"chinese"``; language of ``old_str``.
        target_lang: ``"english"`` or ``"chinese"``; language of ``new_str``.

    Returns:
        mfa_start, mfa_end: per-phone start/end times from the aligner's
            intervals, as floats.
        old_phns: aligned phones of the original recording.
        new_phns: ``left + mid + right``. The unchanged left/right context
            re-uses the aligned phones of the original; the middle comes
            from the front-end.
        span_to_repl: ``[start, end)`` phone indices in ``old_phns`` to replace.
        span_to_add: ``[start, end)`` phone indices in ``new_phns`` to generate.

    NOTE(review): if every original word matches during the right-to-left
    scan (e.g. words inserted only before the first original word), no span
    end is set. ``span_to_repl[1]`` / ``span_to_add[1]`` keep their initial
    ``len(...) - 1`` values and ``new_phns_mid`` stays empty. Confirm this
    case is handled by callers.
    """
    is_append = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []
    # source: force-align the original recording with its transcript
    if source_lang == "english":
        intervals, word2phns = alignment(wav_path, old_str)
    elif source_lang == "chinese":
        intervals, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)

        # Store each word's phones as one space-separated string: word2phns
        # values are consumed with .split() further below.
        for key, value in tp_word2phns.items():
            idx, wrd = key.split('_')
            cur_val = " ".join(value)
            tp_word2phns[key] = cur_val

        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        assert source_lang == "chinese" or source_lang == "english", \
            "source_lang is wrong..."

    # aligned phones of the original recording and their start/end times
    for item in intervals:
        old_phns.append(item[0])
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
    # target
    # cross-lingual clone: new_str extends old_str in the other language
    cross_lingual_clone = is_append and (source_lang != target_lang)

    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]

        # The original part is in source_lang (!= target_lang) and uses that
        # language's front-end; the appended part uses target_lang's.
        if target_lang == "chinese":
            phns_origin, origin_word2phns = words2phns(str_origin)
            phns_append, append_word2phns_tmp = words2phns_zh(str_append)

        elif target_lang == "english":
            # original sentence
            phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # sentence to clone (the appended part)
            phns_append, append_word2phns_tmp = words2phns(str_append)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not support for this language, please check it."

        new_phns = phns_origin + phns_append

        # shift the appended words' indices past the original words
        append_word2phns = {}
        length = len(origin_word2phns)
        for key, value in append_word2phns_tmp.items():
            idx, wrd = key.split('_')
            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
        new_word2phns = origin_word2phns.copy()
        new_word2phns.update(append_word2phns)

    else:
        if source_lang == target_lang and target_lang == "english":
            new_phns, new_word2phns = words2phns(new_str)
        elif source_lang == target_lang and target_lang == "chinese":
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
            # a non-appending edit across languages is rejected
            assert source_lang == target_lang, \
                "source language is not same with target language..."

    # NOTE(review): these initial ends are len(...) - 1, whereas the values
    # assigned below are exclusive ends; callers slice with [start:end].
    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
    new_phns_left = []
    sp_count = 0
    # Find the left different index. Walk the original words from the start
    # while they match new_word2phns ("sp" entries are discounted from the
    # index). Matched words copy their aligned phones into new_phns_left, and
    # left_idx advances by the word's length in new_word2phns.
    # NOTE(review): this assumes new_word2phns values are phone sequences,
    # so len() is a phone count; confirm for words2phns.
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_word2phns:
                left_idx += len(new_word2phns[idx + '_' + wrd])
                new_phns_left.extend(word2phns[key].split())
            else:
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
                break

    # reverse word2phns and new_word2phns
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_idx = get_max_idx(word2phns)
    new_word2phns_max_idx = get_max_idx(new_word2phns)
    new_phns_mid = []
    if is_append:
        # Appending: no right context is kept. Everything in new_phns after
        # the matched prefix is generated.
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    # speech edit
    else:
        # Walk the original words from the end. A word at distance d from the
        # last original index (not counting "sp") is looked up at the same
        # distance d from the last index of new_word2phns.
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx)
                                                   - sp_count))
                if idx + '_' + wrd in new_word2phns:
                    # right_idx (<= 0) counts the matched suffix of new_phns
                    right_idx -= len(new_word2phns[idx + '_' + wrd])
                    new_phns_right = word2phns[key].split() + new_phns_right
                else:
                    # First mismatch from the right: everything between the
                    # matched prefix and the matched suffix is edited.
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    # Use an absolute end index. With right_idx == 0 (the very
                    # last word differs), new_phns[left_idx:0] would be empty
                    # and the replacement phones would be dropped.
                    mid_end = max(0, len(new_phns) + right_idx)
                    new_phns_mid = new_phns[left_idx:mid_end]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        # Pure deletion: widen both spans by one phone on each
                        # side (clamped to the bounds) so they are not empty.
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
                    break
    new_phns = new_phns_left + new_phns_mid + new_phns_right
    # Example:
    #   old: For that reason cover should not be given.
    #   new: For that reason cover is impossible to be given.
    #   span_to_repl: [17, 23] "should not"
    #   span_to_add:  [17, 30] "is impossible to"
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add
|
||
|
|
||
|
|
||
|
# The durations measured by MFA and those predicted by the fs2 (FastSpeech2)
# duration predictor may differ. This computes a scale factor between the
# predicted and the measured values.
def get_dur_adj_factor(orig_dur: List[int],
                       pred_dur: List[int],
                       phns: List[str]):
    """Return a trimmed mean of the ``orig / pred`` duration ratios.

    Phones with a predicted duration of 0 and ``'sp'`` phones are skipped.
    With fewer than 5 usable ratios the factor is ``1``. Otherwise the two
    smallest and the two largest ratios are discarded and the rest averaged.
    """
    ratios = np.sort(np.array([
        orig / pred
        for orig, pred, phn in zip(orig_dur, pred_dur, phns)
        if pred != 0 and phn != 'sp'
    ]))
    trim = 2
    if len(ratios) < 5:
        return 1
    return np.average(ratios[trim:-trim])
|
||
|
|
||
|
|
||
|
def prep_feats_with_dur(wav_path: str,
                        mlm_model: nn.Layer,
                        source_lang: str="english",
                        target_lang: str="english",
                        old_str: str="",
                        new_str: str="",
                        mask_reconstruct: bool=False,
                        duration_adjust: bool=True,
                        start_end_sp: bool=False,
                        fs: int=24000,
                        hop_length: int=300):
    '''
    Build the edited waveform and frame-level alignment for the MLM model.

    Steps:
    1. Load the original recording at ``fs``.
    2. Locate the differing phone spans with ``get_phns_and_spans``.
    3. Predict durations with ``eval_durs`` and rescale them to the
       recording's measured durations.
    4. Replace the samples of the old span with zeros whose length is the
       predicted duration of the new span.

    The language defaults are lowercase: every comparison here and in
    ``get_phns_and_spans`` uses "english"/"chinese". The former "English"
    defaults were always rejected.
    ``mlm_model`` and ``mask_reconstruct`` are accepted but not used here.

    Returns:
        np.ndarray: new wav, replace the part to be edited in original wav with 0
        List[str]: new phones
        np.ndarray: per-phone start frame indices of the new utterance
        np.ndarray: per-phone end frame indices of the new utterance
        List[int]: masked mel boundary (frames) of original wav
        List[int]: masked mel boundary (frames) of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)

    mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang)

    if start_end_sp:
        if new_phns[-1] != 'sp':
            new_phns = new_phns + ['sp']
    # Chinese phones are not necessarily all in the fastspeech2 dictionary and
    # are replaced with sp (NOTE(review): presumably inside eval_durs).
    # The original phones are in source_lang, so that is what is checked.
    if source_lang == "english" or source_lang == "chinese":
        old_durs = eval_durs(old_phns, target_lang=source_lang)
    else:
        assert source_lang == "chinese" or source_lang == "english", \
            "calculate duration_predict is not support for this language..."

    # measured (aligned) duration of every original phone, in seconds
    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
    if '[MASK]' in new_str:
        # Regenerate the replaced span of the original phones in place,
        # scaling by the speaking rate of the context on both sides.
        new_phns = old_phns
        span_to_add = span_to_repl
        d_factor_left = get_dur_adj_factor(
            orig_dur=orig_old_durs[:span_to_repl[0]],
            pred_dur=old_durs[:span_to_repl[0]],
            phns=old_phns[:span_to_repl[0]])
        d_factor_right = get_dur_adj_factor(
            orig_dur=orig_old_durs[span_to_repl[1]:],
            pred_dur=old_durs[span_to_repl[1]:],
            phns=old_phns[span_to_repl[1]:])
        d_factor = (d_factor_left + d_factor_right) / 2
        new_durs_adjusted = [d_factor * i for i in old_durs]
    else:
        if duration_adjust:
            d_factor = get_dur_adj_factor(
                orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
            # NOTE(review): fixed 1.25 stretch of the new phones; its origin
            # is not documented here.
            d_factor = d_factor * 1.25
        else:
            d_factor = 1

        if target_lang == "english" or target_lang == "chinese":
            new_durs = eval_durs(new_phns, target_lang=target_lang)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "calculate duration_predict is not support for this language..."

        new_durs_adjusted = [d_factor * i for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    # time shift applied to every original phone after the replaced span
    dur_offset = new_span_dur_sum - old_span_dur_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
    # lay the new phones end to end after the untouched left context
    for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + i)
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    # append after the original sentence
    if span_to_repl[0] >= len(mfa_start):
        left_idx = len(wav_org)
        right_idx = left_idx
    # replace inside the original sentence
    else:
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
    # Original audio with the part to be edited replaced by silence; the
    # silence length comes from the predicted (fs2 duration predictor) durations.
    new_wav = np.concatenate(
        [wav_org[:left_idx], blank_wav, wav_org[right_idx:]])

    # 4. get old and new mel span to be mask
    # [92, 92]
    old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
        mfa_start=mfa_start,
        mfa_end=mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_repl)
    # [92, 174]
    # new_mfa_start / new_mfa_end: time-level start and end -> frame level
    new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
        mfa_start=new_mfa_start,
        mfa_end=new_mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_add)

    # old_span_bdy and new_span_bdy are frame-level ranges
    return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy
|
||
|
|
||
|
|
||
|
def prep_feats(mlm_model: nn.Layer,
               wav_path: str,
               source_lang: str="english",
               target_lang: str="english",
               old_str: str="",
               new_str: str="",
               duration_adjust: bool=True,
               start_end_sp: bool=False,
               mask_reconstruct: bool=False,
               fs: int=24000,
               hop_length: int=300,
               token_list: List[str]=None):
    """Build the one-utterance batch consumed by the MLM collate function.

    Args:
        mlm_model, wav_path, source_lang, target_lang, old_str, new_str,
        duration_adjust, start_end_sp, mask_reconstruct, fs, hop_length:
            forwarded to ``prep_feats_with_dur``.
        token_list: model vocabulary; a phone's id is its position in this
            list (the last position wins for duplicates). ``None`` (the
            default) means an empty vocabulary.

    Returns:
        batch: ``[('1', {...})]`` with keys ``speech``, ``align_start``,
            ``align_end``, ``text`` (phone ids) and ``span_bdy`` (the new
            span boundary as an ``np.ndarray``).
        old_span_bdy, new_span_bdy: boundaries returned by
            ``prep_feats_with_dur``.

    Raises:
        KeyError: if a phone is not in ``token_list`` and the list has no
            ``'<unk>'`` entry.
    """
    # None instead of a shared mutable [] default.
    if token_list is None:
        token_list = []

    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
        fs=fs,
        hop_length=hop_length)

    token_to_id = {item: i for i, item in enumerate(token_list)}

    def _phn_to_id(phn):
        # Unknown phones map to '<unk>'. '<unk>' is looked up only when a
        # phone is actually unknown, so a vocabulary without it still works
        # when every phone is covered.
        if phn in token_to_id:
            return token_to_id[phn]
        if '<unk>' not in token_to_id:
            raise KeyError('<unk>')
        return token_to_id['<unk>']

    text = np.array([_phn_to_id(phn) for phn in phns])
    span_bdy = np.array(new_span_bdy)

    batch = [('1', {
        "speech": wav,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "text": text,
        "span_bdy": span_bdy
    })]

    return batch, old_span_bdy, new_span_bdy
|
||
|
|
||
|
|
||
|
def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      source_lang: str="english",
                      target_lang: str="english",
                      old_str: str="",
                      new_str: str="",
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      start_end_sp: bool=False,
                      fs: int=24000,
                      hop_length: int=300,
                      token_list: List[str]=None):
    """Run ``mlm_model.inference`` on the edited utterance.

    Args:
        mlm_model: the model whose ``inference`` is called.
        collate_fn: turns the batch from ``prep_feats`` into model inputs;
            element ``[1]`` of its result is used as the feature dict.
        wav_path, source_lang, target_lang, old_str, new_str,
        duration_adjust, start_end_sp, fs, hop_length: forwarded to
            ``prep_feats``.
        use_teacher_forcing: forwarded to ``mlm_model.inference``.
        token_list: model vocabulary; ``None`` (the default) means an empty
            list.

    Returns:
        (wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length).
        ``wav_org`` is the original recording reloaded at ``fs``.
        ``output_feat`` is the model output concatenated along axis 0.
    """
    # None instead of a shared mutable [] default; prep_feats gets a real list.
    if token_list is None:
        token_list = []

    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=fs,
        hop_length=hop_length,
        token_list=token_list)

    feats = collate_fn(batch)[1]

    # Not one of the inputs passed to inference below.
    if 'text_masked_pos' in feats:
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the output pieces along axis 0 (frames)
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
|
||
|
|
||
|
|
||
|
def get_mlm_output(wav_path: str,
                   model_name: str="paddle_checkpoint_en",
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
    """Load the model, build its collate function and decode one utterance.

    The model is loaded (via ``load_model``) and put in eval mode on every
    call. Feature-extraction settings come from the model's training config.

    Returns:
        The tuple produced by ``decode_with_model``:
        (wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length).
    """
    mlm_model, train_conf = load_model(model_name)
    mlm_model.eval()

    feat_conf = train_conf.feats_extract_conf
    uses_sega = train_conf.encoder_conf['input_layer'] == 'sega_mlm'
    collate_fn = build_mlm_collate_fn(
        sr=feat_conf['fs'],
        n_fft=feat_conf['n_fft'],
        hop_length=feat_conf['hop_length'],
        win_length=feat_conf['win_length'],
        n_mels=feat_conf['n_mels'],
        fmin=feat_conf['fmin'],
        fmax=feat_conf['fmax'],
        mlm_prob=train_conf['mlm_prob'],
        mean_phn_span=train_conf['mean_phn_span'],
        seg_emb=uses_sega)

    return decode_with_model(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=feat_conf['fs'],
        hop_length=feat_conf['hop_length'],
        token_list=train_conf.token_list)
|
||
|
|
||
|
|
||
|
def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             prefix: os.PathLike="./prompt/dev/",
             model_name: str="paddle_checkpoint_en",
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):
    """Run an edit/synthesis task for utterance ``uid`` of the prompt set.

    The text actually synthesized depends on ``task_name``:

    - ``'edit'``: ``new_str`` as given.
    - ``'synthesize'``: the original transcript followed by ``new_str``.
    - anything else: the original transcript followed by the Chinese
      characters of ``new_str``, joined with spaces.

    ``prompt_decoding`` is accepted but not used.

    Returns:
        The dict from ``get_wav`` (keys ``"origin"`` and ``"output"``).
    """
    # get origin text and path of origin wav
    old_str, wav_path = read_data(uid=uid, prefix=prefix)

    if task_name == 'synthesize':
        new_str = old_str + new_str
    elif task_name != 'edit':
        han_chars = [ch for ch in new_str if is_chinese(ch)]
        new_str = old_str + ' '.join(han_chars)

    print('new_str is ', new_str)

    return get_wav(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Parse the command line, run the requested task and save the result.
    cli_args = parse_args()

    result = evaluate(
        uid=cli_args.uid,
        source_lang=cli_args.source_lang,
        target_lang=cli_args.target_lang,
        prefix=cli_args.prefix,
        model_name=cli_args.model_name,
        new_str=cli_args.new_str,
        task_name=cli_args.task_name)
    # NOTE(review): the output sample rate is hard-coded to 24000, while the
    # model's rate comes from its config (feats_extract_conf['fs']); keep the
    # two in sync.
    sf.write(cli_args.output_name, result['output'], samplerate=24000)
    print("finished...")
|