parent
97965f4c37
commit
9d4161ce5f
@ -0,0 +1,346 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
from pathlib import Path
from typing import List

import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from paddle import nn
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs
from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor
from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy
from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import norm
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _p2id(self, phonemes: List[str]=None) -> np.ndarray:
    '''
    Map phoneme symbols to their integer ids.

    Each phoneme is looked up in the module-level ``vocab_phones`` mapping;
    symbols missing from that mapping are replaced with "sp" first.

    The leading ``self`` parameter is never read. It is kept so two-argument
    calls keep working, and a call with a single positional argument
    (``_p2id(phns)``) is treated as passing the phoneme list.

    Args:
        self: unused placeholder, or the phoneme list when the function is
            called with a single argument.
        phonemes (List[str]): phoneme symbols to convert.

    Returns:
        np.ndarray: int64 array holding one id per input phoneme.
    '''
    if phonemes is None:
        # Single-argument call: the phoneme list was bound to ``self``.
        phonemes = self
    # replace unk phone with sp
    phone_ids = [
        vocab_phones[phn if phn in vocab_phones else "sp"] for phn in phonemes
    ]
    return np.array(phone_ids, np.int64)
|
||||
|
||||
|
||||
|
||||
def prep_feats_with_dur(wav_path: str,
                        old_str: str='',
                        new_str: str='',
                        source_lang: str='en',
                        target_lang: str='en',
                        duration_adjust: bool=True,
                        fs: int=24000,
                        n_shift: int=300):
    '''
    Blank out the edited span of a recording and rebuild its phone alignment.

    The alignment of ``old_str`` and the spans to replace / add come from
    ``get_phns_spans``. The length of the new span is taken from the
    durations returned by ``eval_durs`` (optionally rescaled by the ratio
    between aligned and predicted durations of the original phones), and the
    matching part of the waveform is replaced by zeros of that length.

    Side effect: the resulting waveform is also written to "new_wav.wav" in
    the current working directory.

    Args:
        wav_path (str): path of the original recording.
        old_str (str): text of the original recording.
        new_str (str): text after editing.
        source_lang (str): language of ``old_str``, 'en' or 'zh'.
        target_lang (str): language of ``new_str``, 'en' or 'zh'.
        duration_adjust (bool): rescale the predicted durations with the
            factor returned by ``get_dur_adj_factor``.
        fs (int): sample rate used to load and to write the audio.
        n_shift (int): number of samples per alignment frame.

    Returns:
        dict: with the keys
            'new_wav' (np.ndarray): new wav, the part to be edited in the
                original wav is replaced with 0.
            'new_phns': new phones.
            'new_mfa_start': per-phone start of the new wav, in frames.
            'new_mfa_end': per-phone end of the new wav, in frames.
            'old_span_bdy': masked mel boundary of the original wav.
            'new_span_bdy': masked mel boundary of the new wav.

    Raises:
        AssertionError: if ``source_lang`` or ``target_lang`` is not 'en' or
            'zh'. The type matches the previous ``assert`` based check but is
            raised explicitly so it still fires under ``python -O``.
    '''
    # Fail fast, before the costly alignment: eval_durs only knows 'en' and
    # 'zh', and it is called with source_lang for the old phones and with
    # target_lang for the new ones.
    for lang in (source_lang, target_lang):
        if lang not in {'en', 'zh'}:
            raise AssertionError(
                "calculate duration_predict is not support for this language..."
            )

    wav_org, _ = librosa.load(wav_path, sr=fs)
    phns_spans_outs = get_phns_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang,
        fs=fs,
        n_shift=n_shift)

    mfa_start = phns_spans_outs["mfa_start"]
    mfa_end = phns_spans_outs["mfa_end"]
    old_phns = phns_spans_outs["old_phns"]
    new_phns = phns_spans_outs["new_phns"]
    span_to_repl = phns_spans_outs["span_to_repl"]
    span_to_add = phns_spans_outs["span_to_add"]

    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]

    if duration_adjust:
        # Chinese phones are not guaranteed to be in the FastSpeech2
        # dictionary; eval_durs substitutes "sp" for the unknown ones.
        # The old durations are only needed for the adjustment factor, so the
        # predictor is not run for them when duration_adjust is off.
        old_durs = eval_durs(old_phns, target_lang=source_lang)
        d_factor = get_dur_adj_factor(
            orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
        # NOTE(review): the extra 1.25 widening is undocumented; it looks
        # like an empirical constant -- confirm before changing it.
        d_factor = d_factor * 1.25
    else:
        d_factor = 1

    new_durs = eval_durs(new_phns, target_lang=target_lang)

    # durations must be integers
    new_durs_adjusted = [int(np.ceil(d_factor * i)) for i in new_durs]

    added_durs = new_durs_adjusted[span_to_add[0]:span_to_add[1]]
    new_span_dur_sum = sum(added_durs)
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum

    # Phones before the edited span keep their original timing.
    new_mfa_start = list(mfa_start[:span_to_repl[0]])
    new_mfa_end = list(mfa_end[:span_to_repl[0]])

    # Lay the new phones out back to back, starting where the kept prefix
    # ends (or at 0 when there is no prefix).
    cursor = new_mfa_end[-1] if new_mfa_end else 0
    for dur in added_durs:
        new_mfa_start.append(cursor)
        cursor = cursor + dur
        new_mfa_end.append(cursor)

    # Phones after the edited span are shifted by the change in span length.
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    if span_to_repl[0] >= len(mfa_start):
        # append after the original sentence
        wav_left_idx = len(wav_org)
        wav_right_idx = wav_left_idx
    else:
        # replace in the middle of the original sentence
        wav_left_idx = int(np.floor(mfa_start[span_to_repl[0]] * n_shift))
        wav_right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * n_shift))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * n_shift)), ), dtype=wav_org.dtype)
    # Original audio with the part to edit replaced by silence; the length of
    # the silence is decided by the FastSpeech2 duration predictor.
    new_wav = np.concatenate(
        [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]])

    # NOTE(review): this dump to the working directory looks like a debugging
    # aid (used to check that the audio is masked correctly); kept as is.
    sf.write("new_wav.wav", new_wav, samplerate=fs)

    # 4. get old and new mel span to be mask
    old_span_bdy = get_span_bdy(
        mfa_start=mfa_start, mfa_end=mfa_end, span_to_repl=span_to_repl)

    new_span_bdy = get_span_bdy(
        mfa_start=new_mfa_start, mfa_end=new_mfa_end, span_to_repl=span_to_add)

    # old_span_bdy and new_span_bdy are frame-level ranges
    outs = {}
    outs['new_wav'] = new_wav
    outs['new_phns'] = new_phns
    outs['new_mfa_start'] = new_mfa_start
    outs['new_mfa_end'] = new_mfa_end
    outs['old_span_bdy'] = old_span_bdy
    outs['new_span_bdy'] = new_span_bdy
    return outs
|
||||
|
||||
|
||||
|
||||
|
||||
def prep_feats(wav_path: str,
               old_str: str='',
               new_str: str='',
               source_lang: str='en',
               target_lang: str='en',
               duration_adjust: bool=True,
               fs: int=24000,
               n_shift: int=300):
    '''
    Build the collated model input for one edited utterance.

    Runs ``prep_feats_with_dur``, converts the new phones to ids, extracts
    and normalises the log-mel features of the masked waveform, saves them
    under "./tmp_dir/<tmp name>/mel.npy" and collates a single datum.

    Relies on the module-level names ``mel_extractor``, ``erniesat_stat``,
    ``collate_fn`` and ``vocab_phones`` (the latter through ``_p2id``), which
    this file only defines in its ``__main__`` section.

    Args:
        wav_path (str): path of the original recording.
        old_str (str): text of the original recording.
        new_str (str): text after editing.
        source_lang (str): language of ``old_str``.
        target_lang (str): language of ``new_str``.
        duration_adjust (bool): forwarded to ``prep_feats_with_dur``.
        fs (int): sample rate, forwarded to ``prep_feats_with_dur``.
        n_shift (int): frame shift, forwarded to ``prep_feats_with_dur``.

    Returns:
        tuple: (batch, old_span_bdy, new_span_bdy), where ``batch`` is the
            output of ``collate_fn`` for the single datum.
    '''
    outs = prep_feats_with_dur(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang,
        duration_adjust=duration_adjust,
        fs=fs,
        n_shift=n_shift)

    wav_name = os.path.basename(wav_path)
    utt_id = wav_name.split('.')[0]

    wav = outs['new_wav']
    phns = outs['new_phns']
    mfa_start = outs['new_mfa_start']
    mfa_end = outs['new_mfa_end']
    old_span_bdy = outs['old_span_bdy']
    new_span_bdy = outs['new_span_bdy']
    span_bdy = np.array(new_span_bdy)

    # _p2id declares an unused leading ``self`` parameter; pass a placeholder
    # for it so the phone list lands in ``phonemes``.
    text = _p2id(None, phns)
    mel = mel_extractor.get_log_mel_fbank(wav)
    erniesat_mean, erniesat_std = np.load(erniesat_stat)
    normed_mel = norm(mel, erniesat_mean, erniesat_std)
    tmp_name = get_tmp_name(text=old_str)
    tmpbase = Path('./tmp_dir/' + tmp_name)
    tmpbase.mkdir(parents=True, exist_ok=True)
    print("tmp_name in synthesize_e2e:", tmp_name)

    mel_path = tmpbase / 'mel.npy'
    print("mel_path:", mel_path)
    # Save the normalised features (the previous code referenced an undefined
    # name ``logmel`` here).
    np.save(mel_path, normed_mel)
    durations = [e - s for e, s in zip(mfa_end, mfa_start)]

    datum = {
        "utt_id": utt_id,
        "spk_id": 0,
        "text": text,
        "text_lengths": len(text),
        # Number of feature frames of this utterance instead of a fixed 115;
        # assumes normed_mel is frame-major -- TODO confirm.
        "speech_lengths": len(normed_mel),
        "durations": durations,
        # NOTE(review): the path of the saved features is passed, not the
        # array itself -- confirm this is what collate_fn expects.
        "speech": mel_path,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "span_bdy": span_bdy
    }

    batch = collate_fn([datum])
    print("batch:", batch)

    return batch, old_span_bdy, new_span_bdy
|
||||
|
||||
|
||||
def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      old_str: str='',
                      new_str: str='',
                      source_lang: str='en',
                      target_lang: str='en',
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      fs: int=24000,
                      n_shift: int=300,
                      token_list: List[str]=[]):
    '''
    Run the masked-LM model on one edited utterance.

    Args:
        mlm_model (nn.Layer): model exposing ``inference``.
        collate_fn: callable applied to the batch returned by ``prep_feats``;
            element 1 of its result is used as the feature dict.
        wav_path (str): path of the original recording.
        old_str (str): text of the original recording.
        new_str (str): text after editing.
        source_lang (str): language of ``old_str``.
        target_lang (str): language of ``new_str``.
        use_teacher_forcing (bool): forwarded to ``mlm_model.inference``.
        duration_adjust (bool): forwarded to ``prep_feats``.
        fs (int): sample rate used to load the audio.
        n_shift (int): frame shift (hop size) in samples.
        token_list (List[str]): accepted for interface compatibility and
            currently unused; ``prep_feats`` takes no such argument.

    Returns:
        tuple: (wav_org, output_feat, old_span_bdy, new_span_bdy, fs,
            n_shift), where ``output_feat`` is the model output concatenated
            along axis 0.
    '''
    # token_list is not forwarded: prep_feats has no such parameter and the
    # call would raise TypeError.
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        fs=fs,
        n_shift=n_shift)

    # NOTE(review): prep_feats already returns the output of the module-level
    # collate_fn; verify that collating it again and taking element 1 is
    # still the intended contract.
    feats = collate_fn(batch)[1]

    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the generated pieces
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    # The hop length is the n_shift argument; ``hop_length`` was never defined
    # in this scope.
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, n_shift
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Demo driver: sets up the globals that prep_feats / _p2id read
    # (mel_extractor, erniesat_stat, collate_fn, vocab_phones) and runs
    # prep_feats once on a sample utterance.
    fs = 24000
    n_shift = 300
    wav_path = "exp/p243_313.wav"
    old_str = "For that reason cover should not be given."
    # for edit
    # new_str = "for that reason cover is impossible to be given."
    # for synthesize
    append_str = "do you love me i love you so much"
    new_str = old_str + append_str

    '''
    outs = prep_feats_with_dur(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        fs=fs,
        n_shift=n_shift)

    new_wav = outs['new_wav']
    new_phns = outs['new_phns']
    new_mfa_start = outs['new_mfa_start']
    new_mfa_end = outs['new_mfa_end']
    old_span_bdy = outs['old_span_bdy']
    new_span_bdy = outs['new_span_bdy']

    print("---------------------------------")

    print("new_wav:", new_wav)
    print("new_phns:", new_phns)
    print("new_mfa_start:", new_mfa_start)
    print("new_mfa_end:", new_mfa_end)
    print("old_span_bdy:", old_span_bdy)
    print("new_span_bdy:", new_span_bdy)
    print("---------------------------------")
    '''

    # NOTE(review): the paths below are absolute paths into one developer's
    # home directory.
    erniesat_config = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/local/default.yaml"

    # The name is rebound from the path to the parsed configuration.
    with open(erniesat_config) as cfg_file:
        erniesat_config = CfgNode(yaml.safe_load(cfg_file))

    erniesat_stat = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/train/speech_stats.npy"

    # Extractor
    mel_extractor = LogMelFBank(
        sr=erniesat_config.fs,
        n_fft=erniesat_config.n_fft,
        hop_length=erniesat_config.n_shift,
        win_length=erniesat_config.win_length,
        window=erniesat_config.window,
        n_mels=erniesat_config.n_mels,
        fmin=erniesat_config.fmin,
        fmax=erniesat_config.fmax)

    collate_fn = build_erniesat_collate_fn(
        mlm_prob=erniesat_config.mlm_prob,
        mean_phn_span=erniesat_config.mean_phn_span,
        seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
        text_masking=False)

    phones_dict = '/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt'

    # Each line of the map is "<phone> <id>".
    vocab_phones = {}
    with open(phones_dict, 'rt') as map_file:
        for map_line in map_file:
            phn, phn_idx = map_line.strip().split()
            vocab_phones[phn] = int(phn_idx)

    prep_feats(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        fs=fs,
        n_shift=n_shift)
|
||||
|
||||
|
||||
|
@ -0,0 +1,216 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import yaml
|
||||
from yacs.config import CfgNode
|
||||
import hashlib
|
||||
|
||||
|
||||
from paddlespeech.t2s.exps.syn_utils import get_am_inference
|
||||
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
|
||||
|
||||
def _get_user():
    """Return the text after the last '/' of the expanded home directory."""
    home_dir = os.path.expanduser('~')
    return home_dir.rpartition('/')[2]
|
||||
|
||||
def str2md5(string):
    """Return the hexadecimal MD5 digest of ``string`` encoded as UTF-8."""
    hasher = hashlib.md5()
    hasher.update(string.encode('utf8'))
    return hasher.hexdigest()
|
||||
|
||||
def get_tmp_name(text: str):
    """Build a temporary name of the form "<user>_<pid>_<md5 of text>"."""
    parts = (_get_user(), str(os.getpid()), str2md5(text))
    return '_'.join(parts)
|
||||
|
||||
def get_dict(dictfile: str):
    """Load a word-to-phones dictionary from a whitespace separated file.

    Every non-blank line is "<word> <phone> <phone> ...". When a word occurs
    more than once, the first entry wins.

    Args:
        dictfile (str): path of the dictionary file.

    Returns:
        dict: maps each word to its phones joined by single spaces.
    """
    word2phns_dict = {}
    with open(dictfile, 'r') as fid:
        for line in fid:
            fields = line.split()
            if not fields:
                # A blank line has no word; indexing it would raise
                # IndexError.
                continue
            word, phn_lst = fields[0], fields[1:]
            # setdefault keeps the first pronunciation seen for a word.
            word2phns_dict.setdefault(word, ' '.join(phn_lst))
    return word2phns_dict
|
||||
|
||||
|
||||
# Get the range of mel frames that needs to be masked.
|
||||
def get_span_bdy(mfa_start: List[float],
                 mfa_end: List[float],
                 span_to_repl: List[int]):
    """Return the [start, end] boundary covered by a span of phones.

    Args:
        mfa_start (List[float]): start of every phone.
        mfa_end (List[float]): end of every phone.
        span_to_repl (List[int]): ``[first, last_exclusive]`` phone indices
            of the span.

    Returns:
        list: ``[start, end]``. When the span begins past the last phone the
            boundary collapses to the end of the last phone, or to ``[0, 0]``
            when there are no phones at all.
    """
    if span_to_repl[0] >= len(mfa_start):
        # The span lies after the existing phones, so it is empty and sits at
        # the end of the alignment. With no phones, that end is 0.
        tail = mfa_end[-1] if len(mfa_end) > 0 else 0
        span_bdy = [tail, tail]
    else:
        span_bdy = [mfa_start[span_to_repl[0]], mfa_end[span_to_repl[1] - 1]]
    return span_bdy
|
||||
|
||||
|
||||
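# The durations obtained from MFA and those predicted by the FastSpeech2
# duration predictor may differ. Compute a scaling factor here, used to scale
# between the predicted values and the real values.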
|
||||
def get_dur_adj_factor(orig_dur: List[int],
                       pred_dur: List[int],
                       phns: List[str]):
    """Return the scale between original and predicted phone durations.

    The per-phone ratio ``orig / pred`` is collected for every phone that is
    not 'sp' and whose predicted duration is not 0. With fewer than 5 ratios
    the factor is 1; otherwise it is the average of the sorted ratios after
    dropping the 2 smallest and the 2 largest.

    Args:
        orig_dur (List[int]): durations of the original phones.
        pred_dur (List[int]): predicted durations of the same phones.
        phns (List[str]): the phones themselves.

    Returns:
        The scaling factor (the int 1 or a numpy float).
    """
    trim = 2
    ratios = np.array([
        orig / pred
        for orig, pred, phn in zip(orig_dur, pred_dur, phns)
        if not (pred == 0 or phn == 'sp')
    ])
    if len(ratios) < 5:
        return 1
    ordered = np.sort(ratios)
    return np.average(ordered[trim:-trim])
|
||||
|
||||
|
||||
def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
    """Parse a two-column text file into a dict.

    The first whitespace separated token of a line is the key and the rest
    of the line (right-stripped) is the value; a line holding only a key maps
    to the empty string.

    Examples:
        wav.scp:
            key1 /some/path/a.wav
            key2 /some/path/b.wav

        >>> read_2col_text('wav.scp')
        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}

    Raises:
        RuntimeError: if a key appears more than once.
    """
    table = {}
    with Path(path).open("r", encoding="utf-8") as f:
        for linenum, line in enumerate(f, 1):
            fields = line.rstrip().split(maxsplit=1)
            k, v = (fields[0], "") if len(fields) == 1 else fields
            if k in table:
                raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
            table[k] = v
    return table
|
||||
|
||||
|
||||
def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
                           ) -> Dict[str, List[Union[float, int]]]:
    """Read a text file indicating sequences of number

    Every line is "<key> <numbers>", where the numbers are separated by a
    space ("text_*" loaders) or by a comma ("csv_*" loaders).

    Examples:
        key1 1 2 3
        key2 34 5 6

        >>> d = load_num_sequence_text('text', loader_type='text_int')
        >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))

    Args:
        path: file to read.
        loader_type: one of "text_int", "text_float", "csv_int", "csv_float".

    Returns:
        dict mapping every key to a list of int or float values.

    Raises:
        ValueError: if ``loader_type`` is not supported, or a value cannot be
            converted to the requested number type.
    """
    loader_specs = {
        "text_int": (" ", int),
        "text_float": (" ", float),
        "csv_int": (",", int),
        "csv_float": (",", float),
    }
    if loader_type not in loader_specs:
        raise ValueError(f"Not supported loader_type={loader_type}")
    delimiter, dtype = loader_specs[loader_type]

    # path looks like:
    #   utta 1,0
    #   uttb 3,4,5
    # -> return {'utta': [1, 0],
    #            'uttb': [3, 4, 5]}
    # The reader defined in this module is read_2col_text.
    d = read_2col_text(path)
    # Using for-loop instead of dict-comprehension for debuggability
    retval = {}
    for k, v in d.items():
        try:
            retval[k] = [dtype(i) for i in v.split(delimiter)]
        except (TypeError, ValueError):
            # int("x") / float("x") raise ValueError, so it has to be caught
            # as well for the diagnostic below to be printed.
            print(f'Error happened with path="{path}", id="{k}", value="{v}"')
            raise
    return retval
|
||||
|
||||
|
||||
def is_chinese(ch):
    """Tell whether ``ch`` compares within the range U+4E00 .. U+9FFF."""
    return u'\u4e00' <= ch <= u'\u9fff'
|
||||
|
||||
|
||||
def get_voc_out(mel):
    """Run the vocoder selected by the parsed arguments on ``mel``.

    The vocoder name, config, checkpoint and statistics are read from the
    object returned by ``parse_args()``.

    NOTE(review): ``parse_args`` is neither defined nor imported in the
    visible part of this module -- confirm where it is expected to come from.

    Args:
        mel: input passed to the vocoder inference callable.

    Returns:
        The vocoder output with singleton dimensions removed by
        ``np.squeeze``.
    """
    # vocoder
    args = parse_args()
    with open(args.voc_config) as cfg_file:
        voc_config = CfgNode(yaml.safe_load(cfg_file))

    vocoder = get_voc_inference(
        voc=args.voc,
        voc_config=voc_config,
        voc_ckpt=args.voc_ckpt,
        voc_stat=args.voc_stat)

    with paddle.no_grad():
        generated = vocoder(mel)
    return np.squeeze(generated)
|
||||
|
||||
|
||||
def eval_durs(phns, target_lang: str='zh', fs: int=24000, n_shift: int=300):
    """Predict the duration of every phone with a FastSpeech2 acoustic model.

    The acoustic model, its config, checkpoint, statistics and phone-id map
    are picked by ``target_lang`` from fixed paths under "download/". Phones
    that are not in the phone-id map are replaced with "sp" before inference.

    Args:
        phns: phones to predict durations for.
        target_lang (str): 'en' or 'zh'.
        fs (int): currently unused.
        n_shift (int): currently unused.

    Returns:
        list: the duration output of ``am.inference`` converted with
            ``tolist()``.

    Raises:
        ValueError: if ``target_lang`` is neither 'en' nor 'zh'.
    """
    if target_lang == 'en':
        am = "fastspeech2_ljspeech"
        am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
        am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
        am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
        phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
    elif target_lang == 'zh':
        am = "fastspeech2_csmsc"
        am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
        am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
        am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
        phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
    else:
        # Without this branch the code below fails with an UnboundLocalError
        # that does not mention the language.
        raise ValueError(
            f"eval_durs only supports target_lang 'en' or 'zh', got {target_lang!r}"
        )

    # Init body.
    with open(am_config) as f:
        am_config = CfgNode(yaml.safe_load(f))

    # Only the raw acoustic model is needed; the inference wrapper returned
    # alongside it is discarded.
    _, am = get_am_inference(
        am=am,
        am_config=am_config,
        am_ckpt=am_ckpt,
        am_stat=am_stat,
        phones_dict=phones_dict,
        return_am=True)

    # Each line of the map is "<phone> <id>".
    vocab_phones = {}
    with open(phones_dict, "r") as f:
        for line in f:
            phn, phn_idx = line.strip().split()
            vocab_phones[phn] = int(phn_idx)

    phone_ids = [
        vocab_phones[phn if phn in vocab_phones else "sp"] for phn in phns
    ]
    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
    # The second value returned by inference is the duration output.
    _, d_outs, _, _ = am.inference(phone_ids)
    return d_outs.tolist()
|
Loading…
Reference in new issue