|
|
|
@ -34,8 +34,8 @@ from paddlespeech.t2s.utils import str2bool
|
|
|
|
|
|
|
|
|
|
mel_streaming = None
|
|
|
|
|
wav_streaming = None
|
|
|
|
|
stream_first_time = 0.0
|
|
|
|
|
voc_stream_st = 0.0
|
|
|
|
|
streaming_first_time = 0.0
|
|
|
|
|
streaming_voc_st = 0.0
|
|
|
|
|
sample_rate = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -65,7 +65,7 @@ def get_chunks(data, block_size, pad_size, step):
|
|
|
|
|
return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_stream_am_inference(args, am_config):
|
|
|
|
|
def get_streaming_am_inference(args, am_config):
|
|
|
|
|
with open(args.phones_dict, "r") as f:
|
|
|
|
|
phn_id = [line.strip().split() for line in f.readlines()]
|
|
|
|
|
vocab_size = len(phn_id)
|
|
|
|
@ -99,8 +99,8 @@ def init(args):
|
|
|
|
|
frontend = get_frontend(args)
|
|
|
|
|
|
|
|
|
|
# acoustic model
|
|
|
|
|
if args.am == 'fastspeech2-C_csmsc':
|
|
|
|
|
am, am_mu, am_std = get_stream_am_inference(args, am_config)
|
|
|
|
|
if args.am == 'fastspeech2_cnndecoder_csmsc':
|
|
|
|
|
am, am_mu, am_std = get_streaming_am_inference(args, am_config)
|
|
|
|
|
am_infer_info = [am, am_mu, am_std, am_config]
|
|
|
|
|
else:
|
|
|
|
|
am_inference, am_name, am_dataset = get_am_inference(args, am_config)
|
|
|
|
@ -139,7 +139,7 @@ def get_phone(args, frontend, sentence, merge_sentences, get_tone_ids):
|
|
|
|
|
# 生成完整的mel
|
|
|
|
|
def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids):
|
|
|
|
|
# 如果是支持流式的AM模型
|
|
|
|
|
if args.am == 'fastspeech2-C_csmsc':
|
|
|
|
|
if args.am == 'fastspeech2_cnndecoder_csmsc':
|
|
|
|
|
am, am_mu, am_std, am_config = am_infer_info
|
|
|
|
|
orig_hs, h_masks = am.encoder_infer(part_phone_ids)
|
|
|
|
|
if args.am_streaming:
|
|
|
|
@ -183,9 +183,9 @@ def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def stream_voc_infer(args, voc_infer_info, mel_len):
|
|
|
|
|
def streaming_voc_infer(args, voc_infer_info, mel_len):
|
|
|
|
|
global mel_streaming
|
|
|
|
|
global stream_first_time
|
|
|
|
|
global streaming_first_time
|
|
|
|
|
global wav_streaming
|
|
|
|
|
voc_inference, voc_config = voc_infer_info
|
|
|
|
|
block = args.voc_block
|
|
|
|
@ -203,7 +203,7 @@ def stream_voc_infer(args, voc_infer_info, mel_len):
|
|
|
|
|
while valid_end <= mel_len:
|
|
|
|
|
sub_wav = voc_inference(mel_chunk)
|
|
|
|
|
if flag == 1:
|
|
|
|
|
stream_first_time = time.time()
|
|
|
|
|
streaming_first_time = time.time()
|
|
|
|
|
flag = 0
|
|
|
|
|
|
|
|
|
|
# get valid wav
|
|
|
|
@ -233,8 +233,8 @@ def stream_voc_infer(args, voc_infer_info, mel_len):
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
# 非流式AM / 流式AM + 非流式Voc
|
|
|
|
|
def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
|
|
|
|
|
part_tone_ids):
|
|
|
|
|
def am_nonstreaming_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
|
|
|
|
|
part_tone_ids):
|
|
|
|
|
mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids)
|
|
|
|
|
am_infer_time = time.time()
|
|
|
|
|
voc_inference, voc_config = voc_infer_info
|
|
|
|
@ -248,10 +248,10 @@ def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
# 非流式AM + 流式Voc
|
|
|
|
|
def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
|
|
|
|
|
part_tone_ids):
|
|
|
|
|
def nonstreaming_am_streaming_voc(args, am_infer_info, voc_infer_info,
|
|
|
|
|
part_phone_ids, part_tone_ids):
|
|
|
|
|
global mel_streaming
|
|
|
|
|
global stream_first_time
|
|
|
|
|
global streaming_first_time
|
|
|
|
|
global wav_streaming
|
|
|
|
|
|
|
|
|
|
mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids)
|
|
|
|
@ -260,8 +260,8 @@ def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
|
|
|
|
|
# voc streaming
|
|
|
|
|
mel_streaming = mel
|
|
|
|
|
mel_len = mel.shape[0]
|
|
|
|
|
stream_voc_infer(args, voc_infer_info, mel_len)
|
|
|
|
|
first_response_time = stream_first_time
|
|
|
|
|
streaming_voc_infer(args, voc_infer_info, mel_len)
|
|
|
|
|
first_response_time = streaming_first_time
|
|
|
|
|
wav = wav_streaming
|
|
|
|
|
final_response_time = time.time()
|
|
|
|
|
voc_infer_time = final_response_time
|
|
|
|
@ -271,12 +271,12 @@ def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
# 流式AM + 流式 Voc
|
|
|
|
|
def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
|
|
|
|
|
part_tone_ids):
|
|
|
|
|
def streaming_am_streaming_voc(args, am_infer_info, voc_infer_info,
|
|
|
|
|
part_phone_ids, part_tone_ids):
|
|
|
|
|
global mel_streaming
|
|
|
|
|
global stream_first_time
|
|
|
|
|
global streaming_first_time
|
|
|
|
|
global wav_streaming
|
|
|
|
|
global voc_stream_st
|
|
|
|
|
global streaming_voc_st
|
|
|
|
|
mel_streaming = None
|
|
|
|
|
#用来表示开启流式voc的线程
|
|
|
|
|
flag = 1
|
|
|
|
@ -311,15 +311,16 @@ def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
|
|
|
|
|
|
|
|
|
|
if flag and mel_streaming.shape[0] > args.voc_block + args.voc_pad:
|
|
|
|
|
t = threading.Thread(
|
|
|
|
|
target=stream_voc_infer, args=(args, voc_infer_info, mel_len, ))
|
|
|
|
|
target=streaming_voc_infer,
|
|
|
|
|
args=(args, voc_infer_info, mel_len, ))
|
|
|
|
|
t.start()
|
|
|
|
|
voc_stream_st = time.time()
|
|
|
|
|
streaming_voc_st = time.time()
|
|
|
|
|
flag = 0
|
|
|
|
|
|
|
|
|
|
t.join()
|
|
|
|
|
final_response_time = time.time()
|
|
|
|
|
voc_infer_time = final_response_time
|
|
|
|
|
first_response_time = stream_first_time
|
|
|
|
|
first_response_time = streaming_first_time
|
|
|
|
|
wav = wav_streaming
|
|
|
|
|
|
|
|
|
|
return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav
|
|
|
|
@ -337,11 +338,11 @@ def warm_up(args, logger, frontend, am_infer_info, voc_infer_info):
|
|
|
|
|
|
|
|
|
|
if args.voc_streaming:
|
|
|
|
|
if args.am_streaming:
|
|
|
|
|
infer_func = stream_am_stream_voc
|
|
|
|
|
infer_func = streaming_am_streaming_voc
|
|
|
|
|
else:
|
|
|
|
|
infer_func = nostream_am_stream_voc
|
|
|
|
|
infer_func = nonstreaming_am_streaming_voc
|
|
|
|
|
else:
|
|
|
|
|
infer_func = am_nostream_voc
|
|
|
|
|
infer_func = am_nonstreaming_voc
|
|
|
|
|
|
|
|
|
|
merge_sentences = True
|
|
|
|
|
get_tone_ids = False
|
|
|
|
@ -376,11 +377,11 @@ def evaluate(args, logger, frontend, am_infer_info, voc_infer_info):
|
|
|
|
|
# choose infer function
|
|
|
|
|
if args.voc_streaming:
|
|
|
|
|
if args.am_streaming:
|
|
|
|
|
infer_func = stream_am_stream_voc
|
|
|
|
|
infer_func = streaming_am_streaming_voc
|
|
|
|
|
else:
|
|
|
|
|
infer_func = nostream_am_stream_voc
|
|
|
|
|
infer_func = nonstreaming_am_streaming_voc
|
|
|
|
|
else:
|
|
|
|
|
infer_func = am_nostream_voc
|
|
|
|
|
infer_func = am_nonstreaming_voc
|
|
|
|
|
|
|
|
|
|
final_up_duration = 0.0
|
|
|
|
|
sentence_count = 0
|
|
|
|
@ -410,7 +411,7 @@ def evaluate(args, logger, frontend, am_infer_info, voc_infer_info):
|
|
|
|
|
args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids)
|
|
|
|
|
am_time = am_infer_time - am_st
|
|
|
|
|
if args.voc_streaming and args.am_streaming:
|
|
|
|
|
voc_time = voc_infer_time - voc_stream_st
|
|
|
|
|
voc_time = voc_infer_time - streaming_voc_st
|
|
|
|
|
else:
|
|
|
|
|
voc_time = voc_infer_time - am_infer_time
|
|
|
|
|
|
|
|
|
@ -482,8 +483,8 @@ def parse_args():
|
|
|
|
|
'--am',
|
|
|
|
|
type=str,
|
|
|
|
|
default='fastspeech2_csmsc',
|
|
|
|
|
choices=['fastspeech2_csmsc', 'fastspeech2-C_csmsc'],
|
|
|
|
|
help='Choose acoustic model type of tts task. where fastspeech2-C_csmsc supports streaming inference'
|
|
|
|
|
choices=['fastspeech2_csmsc', 'fastspeech2_cnndecoder_csmsc'],
|
|
|
|
|
help='Choose acoustic model type of tts task. where fastspeech2_cnndecoder_csmsc supports streaming inference'
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
|
@ -576,7 +577,7 @@ def main():
|
|
|
|
|
args = parse_args()
|
|
|
|
|
paddle.set_device(args.device)
|
|
|
|
|
if args.am_streaming:
|
|
|
|
|
assert (args.am == 'fastspeech2-C_csmsc')
|
|
|
|
|
assert (args.am == 'fastspeech2_cnndecoder_csmsc')
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger()
|
|
|
|
|
fhandler = logging.FileHandler(filename=args.log_file, mode='w')
|
|
|
|
|