code format, test=doc

pull/1688/head
lym0302 3 years ago
parent 4b111146dc
commit 9d0224460b

@ -1,6 +1,6 @@
model_path=~/.paddlespeech/models/
am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/ ## fastspeech2_c
voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/ ## mb_melgan
am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/
voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/
testdata=../../../../t2s/exps/csmsc_test.txt
# get am file
@ -33,9 +33,13 @@ done
# run test
# am can choose fastspeech2_csmsc or fastspeech2-C_csmsc, where fastspeech2-C_csmsc supports streaming inference.
# am can choose fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc, where fastspeech2_cnndecoder_csmsc supports streaming inference.
# voc can choose hifigan_csmsc and mb_melgan_csmsc, They can both support streaming inference.
python test_online_tts.py --am fastspeech2-C_csmsc \
# When am is fastspeech2_cnndecoder_csmsc and am_pad is set to 12, there is no diff between streaming and non-streaming inference results.
# When voc is mb_melgan_csmsc and voc_pad is set to 14, there is no diff between streaming and non-streaming inference results.
# When voc is hifigan_csmsc and voc_pad is set to 20, there is no diff between streaming and non-streaming inference results.
python test_online_tts.py --am fastspeech2_cnndecoder_csmsc \
--am_config $am_model_dir/$am_config_file \
--am_ckpt $am_model_dir/$am_ckpt_file \
--am_stat $am_model_dir/$am_stat_file \

@ -34,8 +34,8 @@ from paddlespeech.t2s.utils import str2bool
mel_streaming = None
wav_streaming = None
stream_first_time = 0.0
voc_stream_st = 0.0
streaming_first_time = 0.0
streaming_voc_st = 0.0
sample_rate = 0
@ -65,7 +65,7 @@ def get_chunks(data, block_size, pad_size, step):
return chunks
def get_stream_am_inference(args, am_config):
def get_streaming_am_inference(args, am_config):
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
@ -99,8 +99,8 @@ def init(args):
frontend = get_frontend(args)
# acoustic model
if args.am == 'fastspeech2-C_csmsc':
am, am_mu, am_std = get_stream_am_inference(args, am_config)
if args.am == 'fastspeech2_cnndecoder_csmsc':
am, am_mu, am_std = get_streaming_am_inference(args, am_config)
am_infer_info = [am, am_mu, am_std, am_config]
else:
am_inference, am_name, am_dataset = get_am_inference(args, am_config)
@ -139,7 +139,7 @@ def get_phone(args, frontend, sentence, merge_sentences, get_tone_ids):
# 生成完整的mel
def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids):
# 如果是支持流式的AM模型
if args.am == 'fastspeech2-C_csmsc':
if args.am == 'fastspeech2_cnndecoder_csmsc':
am, am_mu, am_std, am_config = am_infer_info
orig_hs, h_masks = am.encoder_infer(part_phone_ids)
if args.am_streaming:
@ -183,9 +183,9 @@ def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids):
@paddle.no_grad()
def stream_voc_infer(args, voc_infer_info, mel_len):
def streaming_voc_infer(args, voc_infer_info, mel_len):
global mel_streaming
global stream_first_time
global streaming_first_time
global wav_streaming
voc_inference, voc_config = voc_infer_info
block = args.voc_block
@ -203,7 +203,7 @@ def stream_voc_infer(args, voc_infer_info, mel_len):
while valid_end <= mel_len:
sub_wav = voc_inference(mel_chunk)
if flag == 1:
stream_first_time = time.time()
streaming_first_time = time.time()
flag = 0
# get valid wav
@ -233,7 +233,7 @@ def stream_voc_infer(args, voc_infer_info, mel_len):
@paddle.no_grad()
# 非流式AM / 流式AM + 非流式Voc
def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
def am_nonstreaming_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
part_tone_ids):
mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids)
am_infer_time = time.time()
@ -248,10 +248,10 @@ def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
@paddle.no_grad()
# 非流式AM + 流式Voc
def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
part_tone_ids):
def nonstreaming_am_streaming_voc(args, am_infer_info, voc_infer_info,
part_phone_ids, part_tone_ids):
global mel_streaming
global stream_first_time
global streaming_first_time
global wav_streaming
mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids)
@ -260,8 +260,8 @@ def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
# voc streaming
mel_streaming = mel
mel_len = mel.shape[0]
stream_voc_infer(args, voc_infer_info, mel_len)
first_response_time = stream_first_time
streaming_voc_infer(args, voc_infer_info, mel_len)
first_response_time = streaming_first_time
wav = wav_streaming
final_response_time = time.time()
voc_infer_time = final_response_time
@ -271,12 +271,12 @@ def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
@paddle.no_grad()
# 流式AM + 流式 Voc
def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
part_tone_ids):
def streaming_am_streaming_voc(args, am_infer_info, voc_infer_info,
part_phone_ids, part_tone_ids):
global mel_streaming
global stream_first_time
global streaming_first_time
global wav_streaming
global voc_stream_st
global streaming_voc_st
mel_streaming = None
#用来表示开启流式voc的线程
flag = 1
@ -311,15 +311,16 @@ def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
if flag and mel_streaming.shape[0] > args.voc_block + args.voc_pad:
t = threading.Thread(
target=stream_voc_infer, args=(args, voc_infer_info, mel_len, ))
target=streaming_voc_infer,
args=(args, voc_infer_info, mel_len, ))
t.start()
voc_stream_st = time.time()
streaming_voc_st = time.time()
flag = 0
t.join()
final_response_time = time.time()
voc_infer_time = final_response_time
first_response_time = stream_first_time
first_response_time = streaming_first_time
wav = wav_streaming
return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav
@ -337,11 +338,11 @@ def warm_up(args, logger, frontend, am_infer_info, voc_infer_info):
if args.voc_streaming:
if args.am_streaming:
infer_func = stream_am_stream_voc
infer_func = streaming_am_streaming_voc
else:
infer_func = nostream_am_stream_voc
infer_func = nonstreaming_am_streaming_voc
else:
infer_func = am_nostream_voc
infer_func = am_nonstreaming_voc
merge_sentences = True
get_tone_ids = False
@ -376,11 +377,11 @@ def evaluate(args, logger, frontend, am_infer_info, voc_infer_info):
# choose infer function
if args.voc_streaming:
if args.am_streaming:
infer_func = stream_am_stream_voc
infer_func = streaming_am_streaming_voc
else:
infer_func = nostream_am_stream_voc
infer_func = nonstreaming_am_streaming_voc
else:
infer_func = am_nostream_voc
infer_func = am_nonstreaming_voc
final_up_duration = 0.0
sentence_count = 0
@ -410,7 +411,7 @@ def evaluate(args, logger, frontend, am_infer_info, voc_infer_info):
args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids)
am_time = am_infer_time - am_st
if args.voc_streaming and args.am_streaming:
voc_time = voc_infer_time - voc_stream_st
voc_time = voc_infer_time - streaming_voc_st
else:
voc_time = voc_infer_time - am_infer_time
@ -482,8 +483,8 @@ def parse_args():
'--am',
type=str,
default='fastspeech2_csmsc',
choices=['fastspeech2_csmsc', 'fastspeech2-C_csmsc'],
help='Choose acoustic model type of tts task. where fastspeech2-C_csmsc supports streaming inference'
choices=['fastspeech2_csmsc', 'fastspeech2_cnndecoder_csmsc'],
help='Choose acoustic model type of tts task. where fastspeech2_cnndecoder_csmsc supports streaming inference'
)
parser.add_argument(
@ -576,7 +577,7 @@ def main():
args = parse_args()
paddle.set_device(args.device)
if args.am_streaming:
assert (args.am == 'fastspeech2-C_csmsc')
assert (args.am == 'fastspeech2_cnndecoder_csmsc')
logger = logging.getLogger()
fhandler = logging.FileHandler(filename=args.log_file, mode='w')

Loading…
Cancel
Save