code format, test=doc

pull/1688/head
lym0302 3 years ago
parent 4b111146dc
commit 9d0224460b

@ -1,6 +1,6 @@
model_path=~/.paddlespeech/models/ model_path=~/.paddlespeech/models/
am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/ ## fastspeech2_c am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/
voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/ ## mb_melgan voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/
testdata=../../../../t2s/exps/csmsc_test.txt testdata=../../../../t2s/exps/csmsc_test.txt
# get am file # get am file
@ -33,9 +33,13 @@ done
# run test # run test
# am can be fastspeech2_csmsc or fastspeech2-C_csmsc, where fastspeech2-C_csmsc supports streaming inference. # am can be fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc, where fastspeech2_cnndecoder_csmsc supports streaming inference.
# voc can be hifigan_csmsc or mb_melgan_csmsc; both support streaming inference. # voc can be hifigan_csmsc or mb_melgan_csmsc; both support streaming inference.
python test_online_tts.py --am fastspeech2-C_csmsc \ # When am is fastspeech2_cnndecoder_csmsc and am_pad is set to 12, there is no difference between streaming and non-streaming inference results.
# When voc is mb_melgan_csmsc and voc_pad is set to 14, there is no difference between streaming and non-streaming inference results.
# When voc is hifigan_csmsc and voc_pad is set to 20, there is no difference between streaming and non-streaming inference results.
python test_online_tts.py --am fastspeech2_cnndecoder_csmsc \
--am_config $am_model_dir/$am_config_file \ --am_config $am_model_dir/$am_config_file \
--am_ckpt $am_model_dir/$am_ckpt_file \ --am_ckpt $am_model_dir/$am_ckpt_file \
--am_stat $am_model_dir/$am_stat_file \ --am_stat $am_model_dir/$am_stat_file \

@ -34,8 +34,8 @@ from paddlespeech.t2s.utils import str2bool
mel_streaming = None mel_streaming = None
wav_streaming = None wav_streaming = None
stream_first_time = 0.0 streaming_first_time = 0.0
voc_stream_st = 0.0 streaming_voc_st = 0.0
sample_rate = 0 sample_rate = 0
@ -65,7 +65,7 @@ def get_chunks(data, block_size, pad_size, step):
return chunks return chunks
def get_stream_am_inference(args, am_config): def get_streaming_am_inference(args, am_config):
with open(args.phones_dict, "r") as f: with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id) vocab_size = len(phn_id)
@ -99,8 +99,8 @@ def init(args):
frontend = get_frontend(args) frontend = get_frontend(args)
# acoustic model # acoustic model
if args.am == 'fastspeech2-C_csmsc': if args.am == 'fastspeech2_cnndecoder_csmsc':
am, am_mu, am_std = get_stream_am_inference(args, am_config) am, am_mu, am_std = get_streaming_am_inference(args, am_config)
am_infer_info = [am, am_mu, am_std, am_config] am_infer_info = [am, am_mu, am_std, am_config]
else: else:
am_inference, am_name, am_dataset = get_am_inference(args, am_config) am_inference, am_name, am_dataset = get_am_inference(args, am_config)
@ -139,7 +139,7 @@ def get_phone(args, frontend, sentence, merge_sentences, get_tone_ids):
# 生成完整的mel # 生成完整的mel
def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids): def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids):
# 如果是支持流式的AM模型 # 如果是支持流式的AM模型
if args.am == 'fastspeech2-C_csmsc': if args.am == 'fastspeech2_cnndecoder_csmsc':
am, am_mu, am_std, am_config = am_infer_info am, am_mu, am_std, am_config = am_infer_info
orig_hs, h_masks = am.encoder_infer(part_phone_ids) orig_hs, h_masks = am.encoder_infer(part_phone_ids)
if args.am_streaming: if args.am_streaming:
@ -183,9 +183,9 @@ def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids):
@paddle.no_grad() @paddle.no_grad()
def stream_voc_infer(args, voc_infer_info, mel_len): def streaming_voc_infer(args, voc_infer_info, mel_len):
global mel_streaming global mel_streaming
global stream_first_time global streaming_first_time
global wav_streaming global wav_streaming
voc_inference, voc_config = voc_infer_info voc_inference, voc_config = voc_infer_info
block = args.voc_block block = args.voc_block
@ -203,7 +203,7 @@ def stream_voc_infer(args, voc_infer_info, mel_len):
while valid_end <= mel_len: while valid_end <= mel_len:
sub_wav = voc_inference(mel_chunk) sub_wav = voc_inference(mel_chunk)
if flag == 1: if flag == 1:
stream_first_time = time.time() streaming_first_time = time.time()
flag = 0 flag = 0
# get valid wav # get valid wav
@ -233,7 +233,7 @@ def stream_voc_infer(args, voc_infer_info, mel_len):
@paddle.no_grad() @paddle.no_grad()
# 非流式AM / 流式AM + 非流式Voc # 非流式AM / 流式AM + 非流式Voc
def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, def am_nonstreaming_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
part_tone_ids): part_tone_ids):
mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids) mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids)
am_infer_time = time.time() am_infer_time = time.time()
@ -248,10 +248,10 @@ def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
@paddle.no_grad() @paddle.no_grad()
# 非流式AM + 流式Voc # 非流式AM + 流式Voc
def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, def nonstreaming_am_streaming_voc(args, am_infer_info, voc_infer_info,
part_tone_ids): part_phone_ids, part_tone_ids):
global mel_streaming global mel_streaming
global stream_first_time global streaming_first_time
global wav_streaming global wav_streaming
mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids) mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids)
@ -260,8 +260,8 @@ def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
# voc streaming # voc streaming
mel_streaming = mel mel_streaming = mel
mel_len = mel.shape[0] mel_len = mel.shape[0]
stream_voc_infer(args, voc_infer_info, mel_len) streaming_voc_infer(args, voc_infer_info, mel_len)
first_response_time = stream_first_time first_response_time = streaming_first_time
wav = wav_streaming wav = wav_streaming
final_response_time = time.time() final_response_time = time.time()
voc_infer_time = final_response_time voc_infer_time = final_response_time
@ -271,12 +271,12 @@ def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
@paddle.no_grad() @paddle.no_grad()
# 流式AM + 流式 Voc # 流式AM + 流式 Voc
def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, def streaming_am_streaming_voc(args, am_infer_info, voc_infer_info,
part_tone_ids): part_phone_ids, part_tone_ids):
global mel_streaming global mel_streaming
global stream_first_time global streaming_first_time
global wav_streaming global wav_streaming
global voc_stream_st global streaming_voc_st
mel_streaming = None mel_streaming = None
#用来表示开启流式voc的线程 #用来表示开启流式voc的线程
flag = 1 flag = 1
@ -311,15 +311,16 @@ def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
if flag and mel_streaming.shape[0] > args.voc_block + args.voc_pad: if flag and mel_streaming.shape[0] > args.voc_block + args.voc_pad:
t = threading.Thread( t = threading.Thread(
target=stream_voc_infer, args=(args, voc_infer_info, mel_len, )) target=streaming_voc_infer,
args=(args, voc_infer_info, mel_len, ))
t.start() t.start()
voc_stream_st = time.time() streaming_voc_st = time.time()
flag = 0 flag = 0
t.join() t.join()
final_response_time = time.time() final_response_time = time.time()
voc_infer_time = final_response_time voc_infer_time = final_response_time
first_response_time = stream_first_time first_response_time = streaming_first_time
wav = wav_streaming wav = wav_streaming
return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav
@ -337,11 +338,11 @@ def warm_up(args, logger, frontend, am_infer_info, voc_infer_info):
if args.voc_streaming: if args.voc_streaming:
if args.am_streaming: if args.am_streaming:
infer_func = stream_am_stream_voc infer_func = streaming_am_streaming_voc
else: else:
infer_func = nostream_am_stream_voc infer_func = nonstreaming_am_streaming_voc
else: else:
infer_func = am_nostream_voc infer_func = am_nonstreaming_voc
merge_sentences = True merge_sentences = True
get_tone_ids = False get_tone_ids = False
@ -376,11 +377,11 @@ def evaluate(args, logger, frontend, am_infer_info, voc_infer_info):
# choose infer function # choose infer function
if args.voc_streaming: if args.voc_streaming:
if args.am_streaming: if args.am_streaming:
infer_func = stream_am_stream_voc infer_func = streaming_am_streaming_voc
else: else:
infer_func = nostream_am_stream_voc infer_func = nonstreaming_am_streaming_voc
else: else:
infer_func = am_nostream_voc infer_func = am_nonstreaming_voc
final_up_duration = 0.0 final_up_duration = 0.0
sentence_count = 0 sentence_count = 0
@ -410,7 +411,7 @@ def evaluate(args, logger, frontend, am_infer_info, voc_infer_info):
args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids) args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids)
am_time = am_infer_time - am_st am_time = am_infer_time - am_st
if args.voc_streaming and args.am_streaming: if args.voc_streaming and args.am_streaming:
voc_time = voc_infer_time - voc_stream_st voc_time = voc_infer_time - streaming_voc_st
else: else:
voc_time = voc_infer_time - am_infer_time voc_time = voc_infer_time - am_infer_time
@ -482,8 +483,8 @@ def parse_args():
'--am', '--am',
type=str, type=str,
default='fastspeech2_csmsc', default='fastspeech2_csmsc',
choices=['fastspeech2_csmsc', 'fastspeech2-C_csmsc'], choices=['fastspeech2_csmsc', 'fastspeech2_cnndecoder_csmsc'],
help='Choose acoustic model type of tts task. where fastspeech2-C_csmsc supports streaming inference' help='Choose acoustic model type of tts task. where fastspeech2_cnndecoder_csmsc supports streaming inference'
) )
parser.add_argument( parser.add_argument(
@ -576,7 +577,7 @@ def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device) paddle.set_device(args.device)
if args.am_streaming: if args.am_streaming:
assert (args.am == 'fastspeech2-C_csmsc') assert (args.am == 'fastspeech2_cnndecoder_csmsc')
logger = logging.getLogger() logger = logging.getLogger()
fhandler = logging.FileHandler(filename=args.log_file, mode='w') fhandler = logging.FileHandler(filename=args.log_file, mode='w')

Loading…
Cancel
Save