From 82992b3ed6eaffd78fa27fae57235488f2ded168 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Mon, 11 Apr 2022 11:00:04 +0800 Subject: [PATCH] add test code, test=doc --- .../server/tests/tts/infer/csmsc_test.txt | 100 +++ paddlespeech/server/tests/tts/infer/run.sh | 64 ++ .../server/tests/tts/infer/test_online_tts.py | 650 ++++++++++++++++++ 3 files changed, 814 insertions(+) create mode 100644 paddlespeech/server/tests/tts/infer/csmsc_test.txt create mode 100644 paddlespeech/server/tests/tts/infer/run.sh create mode 100644 paddlespeech/server/tests/tts/infer/test_online_tts.py diff --git a/paddlespeech/server/tests/tts/infer/csmsc_test.txt b/paddlespeech/server/tests/tts/infer/csmsc_test.txt new file mode 100644 index 000000000..d8cf367cd --- /dev/null +++ b/paddlespeech/server/tests/tts/infer/csmsc_test.txt @@ -0,0 +1,100 @@ +009901 昨日,这名伤者与医生全部被警方依法刑事拘留。 +009902 钱伟长想到上海来办学校是经过深思熟虑的。 +009903 她见我一进门就骂,吃饭时也骂,骂得我抬不起头。 +009904 李述德在离开之前,只说了一句柱驼杀父亲了。 +009905 这种车票和保险单捆绑出售属于重复性购买。 +009906 戴佩妮的男友西米露接唱情歌,让她非常开心。 +009907 观大势,谋大局,出大策始终是该院的办院方针。 +009908 他们骑着摩托回家,正好为农忙时的父母帮忙。 +009909 但是因为还没到退休年龄,只能掰着指头捱日子。 +009910 这几天雨水不断,人们恨不得待在家里不出门。 +009911 没想到徐赟,张海翔两人就此玩起了人间蒸发。 +009912 藤村此番发言可能是为了凸显野田的领导能力。 +009913 程长庚,生在清王朝嘉庆年间,安徽的潜山小县。 +009914 南海海域综合补给基地码头项目正在论证中。 +009915 也就是说今晚成都市民极有可能再次看到飘雪。 +009916 随着天气转热,各地的游泳场所开始人头攒动。 +009917 更让徐先生纳闷的是,房客的手机也打不通了。 +009918 遇到颠簸时,应听从乘务员的安全指令,回座位坐好。 +009919 他在后面呆惯了,怕自己一插身后的人会不满,不敢排进去。 +009920 傍晚七个小人回来了,白雪公主说,你们就是我命中的七个小矮人吧。 +009921 他本想说,教育局管这个,他们是一路的,这样一管岂不是妓女起嫖客? +009922 一种表示商品所有权的财物证券,也称商品证券,如提货单,交货单。 +009923 会有很丰富的东西留下来,说都说不完。 +009924 这句话像从天而降,吓得四周一片寂静。 +009925 记者所在的是受害人家属所在的右区。 +009926 不管哈大爷去哪,它都一步不离地跟着。 +009927 大家抬头望去,一只老鼠正趴在吊顶上。 +009928 我决定过年就辞职,接手我爸的废品站! +009929 最终,中国男子乒乓球队获得此奖项。 +009930 防汛抗旱两手抓,抗旱相对抓的不够。 +009931 图们江下游地区开发开放的进展如何? +009932 这要求中国必须有一个坚强的政党领导。 +009933 再说,关于利益上的事俺俩都不好开口。 +009934 明代瓦剌,鞑靼入侵明境也是通过此地。 +009935 咪咪舔着孩子,把它身上的毛舔干净。 +009936 是否这次的国标修订被大企业绑架了? +009937 判决后,姚某妻子胡某不服,提起上诉。 +009938 由此可以看出邯钢的经济效益来自何处。 +009939 琳达说,是瑜伽改变了她和马儿的生活。 +009940 楼下的保安告诉记者,这里不租也不卖。 +009941 习近平说,中斯两国人民传统友谊深厚。 +009942 传闻越来越多,后来连老汉儿自己都怕了。 +009943 我怒吼一声冲上去,举起砖头砸了过去。 +009944 我现在还不会,这就回去问问发明我的人。 +009945 显然,洛阳性奴案不具备上述两个前提。 +009946 另外,杰克逊有文唇线,眼线,眉毛的动作。 +009947 昨晚,华西都市报记者电话采访了尹琪。 +009948 涅拉季科未透露这些航空公司的名称。 +009949 从运行轨迹上来说,它也不可能是星星。 +009950 目前看,如果继续加息也存在两难问题。 +009951 曾宝仪在节目录制现场大爆观众糗事。 +009952 但任凭周某怎么叫,男子仍酣睡不醒。 +009953 老大爷说,小子,你挡我财路了,知道不? +009954 没料到,闯下大头佛的阿伟还不知悔改。 +009955 卡扎菲部落式统治已遭遇部落内讧。 +009956 这个孩子的生命一半来源于另一位女士捐赠的冷冻卵子。 +009957 出现这种泥鳅内阁的局面既是野田有意为之,也实属无奈。 +009958 济青高速济南,华山,章丘,邹平,周村,淄博,临淄站。 +009959 赵凌飞的话,反映了沈阳赛区所有奥运志愿者的共同心声。 +009960 因为,我们所发出的力量必会因难度加大而减弱。 +009961 发生事故的楼梯拐角处仍可看到血迹。 +009962 想过进公安,可能身高不够,老汉儿也不让我进去。 +009963 路上关卡很多,为了方便撤离,只好轻装前进。 +009964 原来比尔盖茨就是美国微软公司联合创始人呀。 +009965 之后他们一家三口将与双方父母往峇里岛旅游。 +009966 谢谢总理,也感谢广大网友的参与,我们明年再见。 +009967 事实上是,从来没有一个欺善怕恶的人能作出过稍大一点的成就。 +009968 我会打开邮件,你可以从那里继续。 +009969 美方对近期东海局势表示关切。 +009970 据悉,奥巴马一家人对这座冬季白宫极为满意。 +009971 打扫完你会很有成就感的,试一试,你就信了。 +009972 诺曼站在滑板车上,各就各位,准备出发啦! +009973 塔河的寒夜,气温降到了零下三十多摄氏度。 +009974 其间,连破六点六,六点五,六点四,六点三五等多个重要关口。 +009975 算命其实只是人们的一种自我安慰和自我暗示而已,我们还是要相信科学才好。 +009976 这一切都令人欢欣鼓舞,阿讷西没理由不坚持到最后。 +009977 直至公元前一万一千年,它又再次出现。 +009978 尽量少玩电脑,少看电视,少打游戏。 +009979 从五到七,前后也就是六个月的时间。 +009980 一进咖啡店,他就遇见一张熟悉的脸。 +009981 好在众弟兄看到了把她追了回来。 +009982 有一个人说,哥们儿我们跑过它才能活。 +009983 捅了她以后,模糊记得她没咋动了。 +009984 从小到大,葛启义没有收到过压岁钱。 +009985 舞台下的你会对舞台上的你说什么? +009986 但考生普遍认为,试题的怪多过难。 +009987 我希望每个人都能够尊重我们的隐私。 +009988 漫天的红霞使劲给两人增添气氛。 +009989 晚上加完班开车回家,太累了,迷迷糊糊开着车,走一半的时候,铛一声! +009990 该车将三人撞倒后,在大雾中逃窜。 +009991 这人一哆嗦,方向盘也把不稳了,差点撞上了高速边道护栏。 +009992 那女孩儿委屈的说,我一回头见你已经进去了我不敢进去啊! 
+009993 小明摇摇头说,不是,我只是美女看多了,想换个口味而已。
+009994 接下来,红娘要求记者交费,记者表示不知表姐身份证号码。
+009995 李东蓊表示,自己当时在法庭上发表了一次独特的公诉意见。
+009996 另一男子扑了上来,手里拿着明晃晃的长刀,向他胸口直刺。
+009997 今天,快递员拿着一个快递在办公室喊,秦王是哪个,有他快递?
+009998 这场抗议活动究竟是如何发展演变的,又究竟是谁伤害了谁?
+009999 因华国锋肖鸡,墓地设计根据其属相设计。
+010000 在狱中,张明宝悔恨交加,写了一份忏悔书。
diff --git a/paddlespeech/server/tests/tts/infer/run.sh b/paddlespeech/server/tests/tts/infer/run.sh
new file mode 100644
index 000000000..fdceec412
--- /dev/null
+++ b/paddlespeech/server/tests/tts/infer/run.sh
@@ -0,0 +1,64 @@
+model_path=/home/users/liangyunming/.paddlespeech/models/
+#am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_nosil_baker_ckpt_0.4/  ## fastspeech2
+am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/  ## fastspeech2_cnn
+voc_model_dir=$model_path/hifigan_csmsc-zh/hifigan_csmsc_ckpt_0.1.1/  ## hifigan
+#voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/  ## mb_melgan
+
+# only the cnndecoder variant of fastspeech2 supports streaming AM inference
+if [[ $am_model_dir == *"fastspeech2_cnndecoder"* ]]; then
+    am_support_stream=True
+else
+    am_support_stream=False
+fi
+
+# collect the acoustic model files (config, checkpoint, stats, phone dict)
+for file in $(ls $am_model_dir)
+do
+    if [[ $file == *"yaml"* ]]; then
+        am_config_file=$file
+    elif [[ $file == *"pdz"* ]]; then
+        am_ckpt_file=$file
+    elif [[ $file == *"stat"* ]]; then
+        am_stat_file=$file
+    elif [[ $file == *"phone"* ]]; then
+        phones_dict_file=$file
+    fi
+done
+
+# collect the vocoder files (config, checkpoint, stats)
+for file in $(ls $voc_model_dir)
+do
+    if [[ $file == *"yaml"* ]]; then
+        voc_config_file=$file
+    elif [[ $file == *"pdz"* ]]; then
+        voc_ckpt_file=$file
+    elif [[ $file == *"stat"* ]]; then
+        voc_stat_file=$file
+    fi
+done
+
+# run
+python test_online_tts.py --am fastspeech2_csmsc \
+    --am_support_stream $am_support_stream \
+    --am_config $am_model_dir/$am_config_file \
+    --am_ckpt $am_model_dir/$am_ckpt_file \
+    --am_stat $am_model_dir/$am_stat_file \
+    --phones_dict $am_model_dir/$phones_dict_file \
+    --voc hifigan_csmsc \
+    --voc_config $voc_model_dir/$voc_config_file \
+    --voc_ckpt $voc_model_dir/$voc_ckpt_file \
+    --voc_stat $voc_model_dir/$voc_stat_file \
+    --lang zh \
+    --device cpu \
+    --text ./csmsc_test.txt \
+    --output_dir ./output \
+    --log_file ./result.log \
+    --am_streaming False \
+    --am_pad 12 \
+    --am_block 42 \
+    --voc_streaming True \
+    --voc_pad 14 \
+    --voc_block 14
diff --git a/paddlespeech/server/tests/tts/infer/test_online_tts.py b/paddlespeech/server/tests/tts/infer/test_online_tts.py
new file mode 100644
index 000000000..17ac0ea7b
--- /dev/null
+++ b/paddlespeech/server/tests/tts/infer/test_online_tts.py
@@ -0,0 +1,650 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
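+
+# Overview (summary comment added for clarity; derived from the code below):
+# this script benchmarks streaming TTS inference. For each 'utt_id sentence'
+# pair in --text it logs frontend / AM / vocoder time, first and final
+# response latency, and the overall RTF, in one of three modes selected by
+# --am_streaming and --voc_streaming:
+#   1. non-streaming AM + non-streaming vocoder
+#   2. non-streaming AM + streaming vocoder
+#   3. streaming AM + streaming vocoder
+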
+import argparse
+import logging
+import math
+import threading
+import time
+from pathlib import Path
+
+import numpy as np
+import paddle
+import soundfile as sf
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.s2t.utils.dynamic_import import dynamic_import
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
+from paddlespeech.t2s.exps.syn_utils import get_frontend
+from paddlespeech.t2s.exps.syn_utils import get_sentences
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+from paddlespeech.t2s.exps.syn_utils import model_alias
+from paddlespeech.t2s.utils import str2bool
+
+mel_streaming = None
+wav_streaming = None
+stream_first_time = 0.0
+voc_stream_st = 0.0
+sample_rate = 0
+
+
+def denorm(data, mean, std):
+    return data * std + mean
+
+
+def get_chunks(data, block_size, pad_size, step):
+    if step == "am":
+        data_len = data.shape[1]
+    elif step == "voc":
+        data_len = data.shape[0]
+    else:
+        raise ValueError("Please set correct type to get chunks, am or voc")
+
+    chunks = []
+    n = math.ceil(data_len / block_size)
+    for i in range(n):
+        start = max(0, i * block_size - pad_size)
+        end = min((i + 1) * block_size + pad_size, data_len)
+        if step == "am":
+            chunks.append(data[:, start:end, :])
+        else:
+            chunks.append(data[start:end, :])
+    return chunks
+
+
+def get_stream_am_inference(args, am_config):
+    with open(args.phones_dict, "r") as f:
+        phn_id = [line.strip().split() for line in f.readlines()]
+    vocab_size = len(phn_id)
+    print("vocab_size:", vocab_size)
+
+    am_name = args.am[:args.am.rindex('_')]
+    am_dataset = args.am[args.am.rindex('_') + 1:]
+    odim = am_config.n_mels
+
+    am_class = dynamic_import(am_name, model_alias)
+    am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
+    am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
+    am.eval()
+    am_mu, am_std = np.load(args.am_stat)
+    am_mu = paddle.to_tensor(am_mu)
+    am_std = paddle.to_tensor(am_std)
+
+    return am, am_mu, am_std
+
+
+def init(args):
+    global sample_rate
+    # get config
+    with open(args.am_config) as f:
+        am_config = CfgNode(yaml.safe_load(f))
+    with open(args.voc_config) as f:
+        voc_config = CfgNode(yaml.safe_load(f))
+
+    sample_rate = am_config.fs
+
+    # frontend
+    frontend = get_frontend(args)
+
+    # acoustic model
+    if args.am_support_stream:
+        am, am_mu, am_std = get_stream_am_inference(args, am_config)
+        am_infer_info = [am, am_mu, am_std, am_config]
+    else:
+        am_inference, am_name, am_dataset = get_am_inference(args, am_config)
+        am_infer_info = [am_inference, am_name, am_dataset, am_config]
+
+    # vocoder
+    voc_inference = get_voc_inference(args, voc_config)
+    voc_infer_info = [voc_inference, voc_config]
+
+    return frontend, am_infer_info, voc_infer_info
+
+
+def get_phone(args, frontend, sentence, merge_sentences, get_tone_ids):
+    am_name = args.am[:args.am.rindex('_')]
+    tone_ids = None
+    if am_name == 'speedyspeech':
+        get_tone_ids = True
+
+    if args.lang == 'zh':
+        input_ids = frontend.get_input_ids(
+            sentence,
+            merge_sentences=merge_sentences,
+            get_tone_ids=get_tone_ids)
+        phone_ids = input_ids["phone_ids"]
+        if get_tone_ids:
+            tone_ids = input_ids["tone_ids"]
+    elif args.lang == 'en':
+        input_ids = frontend.get_input_ids(
+            sentence, merge_sentences=merge_sentences)
+        phone_ids = input_ids["phone_ids"]
+    else:
+        raise ValueError("lang should be in {'zh', 'en'}!")
+
+    return phone_ids, tone_ids
+
+
+# generate the complete mel spectrogram
+@paddle.no_grad()
+def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids):
+    # the AM supports streaming inference (fastspeech2 with CNN decoder)
+    if args.am_support_stream:
+        am, am_mu, am_std, am_config = am_infer_info
+        orig_hs, h_masks = am.encoder_infer(part_phone_ids)
+        if args.am_streaming:
+            am_pad = args.am_pad
+            am_block = args.am_block
+            hss = get_chunks(orig_hs, am_block, am_pad, "am")
+            chunk_num = len(hss)
+            mel_list = []
+            for i, hs in enumerate(hss):
+                before_outs, _ = am.decoder(hs)
+                after_outs = before_outs + am.postnet(
+                    before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
+                normalized_mel = after_outs[0]
+                sub_mel = denorm(normalized_mel, am_mu, am_std)
+                # clip the pad part of the output
+                if i == 0:
+                    sub_mel = sub_mel[:-am_pad]
+                elif i == chunk_num - 1:
+                    # the last chunk never has a full pad on its right side
+                    sub_mel = sub_mel[am_pad:]
+                else:
+                    # chunks near the end may also lack a full right pad
+                    sub_mel = sub_mel[am_pad:(am_block + am_pad) -
+                                      sub_mel.shape[0]]
+                mel_list.append(sub_mel)
+            mel = paddle.concat(mel_list, axis=0)
+
+        else:
+            before_outs, _ = am.decoder(orig_hs)
+            after_outs = before_outs + am.postnet(
+                before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
+            normalized_mel = after_outs[0]
+            mel = denorm(normalized_mel, am_mu, am_std)
+
+    else:
+        am_inference, am_name, am_dataset, am_config = am_infer_info
+        # acoustic model
+        if am_name == 'fastspeech2':
+            # multi speaker
+            if am_dataset in {"aishell3", "vctk"}:
+                spk_id = paddle.to_tensor(args.spk_id)
+                mel = am_inference(part_phone_ids, spk_id)
+            else:
+                mel = am_inference(part_phone_ids)
+        elif am_name == 'speedyspeech':
+            if am_dataset in {"aishell3", "vctk"}:
+                spk_id = paddle.to_tensor(args.spk_id)
+                mel = am_inference(part_phone_ids, part_tone_ids, spk_id)
+            else:
+                mel = am_inference(part_phone_ids, part_tone_ids)
+        elif am_name == 'tacotron2':
+            mel = am_inference(part_phone_ids)
+
+    return mel
+
+
+@paddle.no_grad()
+def stream_voc_infer(args, voc_infer_info, mel_len):
+    global mel_streaming
+    global stream_first_time
+    global wav_streaming
+    voc_inference, voc_config = voc_infer_info
+    block = args.voc_block
+    pad = args.voc_pad
+    upsample = voc_config.n_shift
+    wav_list = []
+    flag = 1
+
+    valid_start = 0
+    valid_end = min(valid_start + block, mel_len)
+    actual_start = 0
+    actual_end = min(valid_end + pad, mel_len)
+    mel_chunk = mel_streaming[actual_start:actual_end, :]
+
+    while valid_end <= mel_len:
+        sub_wav = voc_inference(mel_chunk)
+        if flag == 1:
+            # record the moment the first wav chunk is ready
+            stream_first_time = time.time()
+            flag = 0
+
+        # get valid wav
+        start = valid_start - actual_start
+        if valid_end == mel_len:
+            sub_wav = sub_wav[start * upsample:]
+            wav_list.append(sub_wav)
+            break
+        else:
+            end = start + block
+            sub_wav = sub_wav[start * upsample:end * upsample]
+            wav_list.append(sub_wav)
+
+        # generate new mel chunk
+        valid_start = valid_end
+        valid_end = min(valid_start + block, mel_len)
+        if valid_start - pad < 0:
+            actual_start = 0
+        else:
+            actual_start = valid_start - pad
+        actual_end = min(valid_end + pad, mel_len)
+        mel_chunk = mel_streaming[actual_start:actual_end, :]
+
+    wav = paddle.concat(wav_list, axis=0)
+    wav_streaming = wav
+
+
+# non-streaming AM (or streaming-capable AM) + non-streaming vocoder
+@paddle.no_grad()
+def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
+                    part_tone_ids):
+    mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids)
+    am_infer_time = time.time()
+    voc_inference, voc_config = voc_infer_info
+    wav = voc_inference(mel)
+    first_response_time = time.time()
+    final_response_time = first_response_time
+    voc_infer_time = first_response_time
+
+    return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav
+
+
+# non-streaming AM + streaming vocoder
+@paddle.no_grad()
+def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
+                           part_tone_ids):
+    global mel_streaming
+    global stream_first_time
+    global wav_streaming
+
+    mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids)
+    am_infer_time = time.time()
+
+    # voc streaming
+    mel_streaming = mel
+    mel_len = mel.shape[0]
+    stream_voc_infer(args, voc_infer_info, mel_len)
+    first_response_time = stream_first_time
+    wav = wav_streaming
+    final_response_time = time.time()
+    voc_infer_time = final_response_time
+
+    return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav
+
+
+# streaming AM + streaming vocoder
+@paddle.no_grad()
+def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids,
+                         part_tone_ids):
+    global mel_streaming
+    global stream_first_time
+    global wav_streaming
+    global voc_stream_st
+    mel_streaming = None
+    flag = 1  # whether the streaming vocoder thread still has to be started
+
+    am, am_mu, am_std, am_config = am_infer_info
+    orig_hs, h_masks = am.encoder_infer(part_phone_ids)
+    mel_len = orig_hs.shape[1]
+    am_block = args.am_block
+    am_pad = args.am_pad
+    hss = get_chunks(orig_hs, am_block, am_pad, "am")
+    chunk_num = len(hss)
+
+    for i, hs in enumerate(hss):
+        before_outs, _ = am.decoder(hs)
+        after_outs = before_outs + am.postnet(
+            before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
+        normalized_mel = after_outs[0]
+        sub_mel = denorm(normalized_mel, am_mu, am_std)
+        # clip the pad part of the output
+        if i == 0:
+            sub_mel = sub_mel[:-am_pad]
+            mel_streaming = sub_mel
+        elif i == chunk_num - 1:
+            # the last chunk never has a full pad on its right side
+            sub_mel = sub_mel[am_pad:]
+            mel_streaming = paddle.concat([mel_streaming, sub_mel])
+        else:
+            # chunks near the end may also lack a full right pad
+            sub_mel = sub_mel[am_pad:(am_block + am_pad) - sub_mel.shape[0]]
+            mel_streaming = paddle.concat([mel_streaming, sub_mel])
+
+        if flag and mel_streaming.shape[0] > args.voc_block + args.voc_pad:
+            # enough mel frames are ready: start the streaming vocoder thread
+            t = threading.Thread(
+                target=stream_voc_infer, args=(args, voc_infer_info, mel_len, ))
+            t.start()
+            voc_stream_st = time.time()
+            flag = 0
+
+    # AM inference for all chunks is done at this point
+    am_infer_time = time.time()
+
+    if flag:
+        # the mel was too short to trigger the thread inside the loop
+        t = threading.Thread(
+            target=stream_voc_infer, args=(args, voc_infer_info, mel_len, ))
+        t.start()
+        voc_stream_st = time.time()
+
+    t.join()
+    final_response_time = time.time()
+    voc_infer_time = final_response_time
+    first_response_time = stream_first_time
+    wav = wav_streaming
+
+    return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav
+
+
+def try_infer(args, logger, frontend, am_infer_info, voc_infer_info):
+    global sample_rate
+    logger.info(
+        "Before the formal test, we run a few sentences of warm-up so that the measured inference speed is stable."
+    )
+    if args.lang == 'zh':
+        sentence = "您好,欢迎使用语音合成服务。"
+    elif args.lang == 'en':
+        sentence = "Hello and welcome to the speech synthesis service."
+
+    if args.voc_streaming:
+        if args.am_streaming:
+            infer_func = stream_am_stream_voc
+        else:
+            infer_func = nostream_am_stream_voc
+    else:
+        infer_func = am_nostream_voc
+
+    merge_sentences = True
+    get_tone_ids = False
+    for i in range(3):  # run inference 3 times
+        st = time.time()
+        phone_ids, tone_ids = get_phone(args, frontend, sentence,
+                                        merge_sentences, get_tone_ids)
+        part_phone_ids = phone_ids[0]
+        if tone_ids:
+            part_tone_ids = tone_ids[0]
+        else:
+            part_tone_ids = None
+
+        am_infer_time, voc_infer_time, first_response_time, final_response_time, wav = infer_func(
+            args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids)
+        wav = wav.numpy()
+        duration = wav.size / sample_rate
+        logger.info(
+            f"sentence: {sentence}; duration: {duration} s; first response time: {first_response_time - st} s; final response time: {final_response_time - st} s"
+        )
+
+
+def evaluate(args, logger, frontend, am_infer_info, voc_infer_info):
+    global sample_rate
+    sentences = get_sentences(args)
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    get_tone_ids = False
+    merge_sentences = True
+
+    # choose infer function
+    if args.voc_streaming:
+        if args.am_streaming:
+            infer_func = stream_am_stream_voc
+        else:
+            infer_func = nostream_am_stream_voc
+    else:
+        infer_func = am_nostream_voc
+
+    final_up_duration = 0
+    sentence_count = 0
+    front_time_list = []
+    am_time_list = []
+    voc_time_list = []
+    first_response_list = []
+    final_response_list = []
+    sentence_length_list = []
+    duration_list = []
+
+    for utt_id, sentence in sentences:
+        # front
+        front_st = time.time()
+        phone_ids, tone_ids = get_phone(args, frontend, sentence,
+                                        merge_sentences, get_tone_ids)
+        part_phone_ids = phone_ids[0]
+        if tone_ids:
+            part_tone_ids = tone_ids[0]
+        else:
+            part_tone_ids = None
+        front_et = time.time()
+        front_time = front_et - front_st
+
+        am_st = time.time()
+        am_infer_time, voc_infer_time, first_response_time, final_response_time, wav = infer_func(
+            args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids)
+        am_time = am_infer_time - am_st
+        if args.voc_streaming and args.am_streaming:
+            voc_time = voc_infer_time - voc_stream_st
+        else:
+            voc_time = voc_infer_time - am_infer_time
+
+        first_response = first_response_time - front_st
+        final_response = final_response_time - front_st
+
+        wav = wav.numpy()
+        duration = wav.size / sample_rate
+        sf.write(
+            str(output_dir / (utt_id + ".wav")), wav, samplerate=sample_rate)
+        print(f"{utt_id} done!")
+
+        sentence_count += 1
+        front_time_list.append(front_time)
+        am_time_list.append(am_time)
+        voc_time_list.append(voc_time)
+        first_response_list.append(first_response)
+        final_response_list.append(final_response)
+        sentence_length_list.append(len(sentence))
+        duration_list.append(duration)
+
+        logger.info(
+            f"uttid: {utt_id}; sentence: '{sentence}'; front time: {front_time} s; "
+            f"am time: {am_time} s; voc time: {voc_time} s; "
+            f"first response time: {first_response} s; final response time: {final_response} s; "
+            f"audio duration: {duration} s;")
+
+        if final_response > duration:
+            final_up_duration += 1
+
+    all_time_sum = sum(final_response_list)
+    front_rate = sum(front_time_list) / all_time_sum
+    am_rate = sum(am_time_list) / all_time_sum
+    voc_rate = sum(voc_time_list) / all_time_sum
+    rtf = all_time_sum / sum(duration_list)
+
+    logger.info(
+        f"Test text length information: sentence num: {sentence_count}; total char num: {sum(sentence_length_list)}; min: {min(sentence_length_list)}; max: {max(sentence_length_list)}; avg: {sum(sentence_length_list)/len(sentence_length_list)}"
+    )
+    logger.info(
+        f"Duration information: min: {min(duration_list)} s; max: {max(duration_list)} s; avg: {sum(duration_list) / len(duration_list)} s; sum: {sum(duration_list)} s"
+    )
+    logger.info(
+        f"Front time information: min: {min(front_time_list)} s; max: {max(front_time_list)} s; avg: {sum(front_time_list)/len(front_time_list)} s; ratio: {front_rate * 100}%"
+    )
+    logger.info(
+        f"AM time information: min: {min(am_time_list)} s; max: {max(am_time_list)} s; avg: {sum(am_time_list)/len(am_time_list)} s; ratio: {am_rate * 100}%"
+    )
+    logger.info(
+        f"Vocoder time information: min: {min(voc_time_list)} s; max: {max(voc_time_list)} s; avg: {sum(voc_time_list)/len(voc_time_list)} s; ratio: {voc_rate * 100}%"
+    )
+    logger.info(
+        f"First response time information: min: {min(first_response_list)} s; max: {max(first_response_list)} s; avg: {sum(first_response_list)/len(first_response_list)} s"
+    )
+    logger.info(
+        f"Final response time information: min: {min(final_response_list)} s; max: {max(final_response_list)} s; avg: {sum(final_response_list)/len(final_response_list)} s"
+    )
+    logger.info(f"RTF is: {rtf}")
+    logger.info(
+        f"The number of sentences whose final response time exceeds the audio duration is {final_up_duration}, ratio: {final_up_duration / sentence_count * 100}%"
+    )
+
+
+def parse_args():
+    # parse args and config
+    parser = argparse.ArgumentParser(
+        description="Synthesize with acoustic model & vocoder")
+    # acoustic model
+    parser.add_argument(
+        '--am',
+        type=str,
+        default='fastspeech2_csmsc',
+        choices=[
+            'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc',
+            'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk',
+            'tacotron2_csmsc', 'tacotron2_ljspeech'
+        ],
+        help='Choose acoustic model type of tts task.')
+    parser.add_argument(
+        '--am_support_stream',
+        type=str2bool,
+        default=False,
+        help='Whether the acoustic model supports streaming inference (only for fastspeech2_csmsc with CNN decoder).'
+    )
+    parser.add_argument(
+        '--am_config',
+        type=str,
+        default=None,
+        help='Config of acoustic model. Use default config when it is None.')
+    parser.add_argument(
+        '--am_ckpt',
+        type=str,
+        default=None,
+        help='Checkpoint file of acoustic model.')
+    parser.add_argument(
+        "--am_stat",
+        type=str,
+        default=None,
+        help="mean and standard deviation used to normalize spectrogram when training acoustic model."
+    )
+    parser.add_argument(
+        "--phones_dict", type=str, default=None, help="phone vocabulary file.")
+    parser.add_argument(
+        "--tones_dict", type=str, default=None, help="tone vocabulary file.")
+    parser.add_argument(
+        "--speaker_dict", type=str, default=None, help="speaker id map file.")
+    parser.add_argument(
+        '--spk_id',
+        type=int,
+        default=0,
+        help='spk id for multi speaker acoustic model')
+    # vocoder
+    parser.add_argument(
+        '--voc',
+        type=str,
+        default='mb_melgan_csmsc',
+        choices=[
+            'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
+            'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc',
+            'wavernn_csmsc'
+        ],
+        help='Choose vocoder type of tts task.')
+    parser.add_argument(
+        '--voc_config',
+        type=str,
+        default=None,
+        help='Config of voc. Use default config when it is None.')
+    parser.add_argument(
+        '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
+    parser.add_argument(
+        "--voc_stat",
+        type=str,
+        default=None,
+        help="mean and standard deviation used to normalize spectrogram when training voc."
+ ) + # other + parser.add_argument( + '--lang', + type=str, + default='zh', + choices=['zh', 'en'], + help='Choose model language. zh or en') + + parser.add_argument( + "--device", type=str, default='cpu', help="set cpu or gpu:id") + + parser.add_argument( + "--text", + type=str, + default="./csmsc_test.txt", + help="text to synthesize, a 'utt_id sentence' pair per line.") + parser.add_argument("--output_dir", type=str, help="output dir.") + parser.add_argument( + "--log_file", type=str, default="result.log", help="log file.") + + parser.add_argument( + "--am_streaming", + type=str2bool, + default=False, + help="whether use streaming acoustic model") + + parser.add_argument("--am_pad", type=int, default=12, help="am pad size.") + + parser.add_argument( + "--am_block", type=int, default=42, help="am block size.") + + parser.add_argument( + "--voc_streaming", + type=str2bool, + default=False, + help="whether use streaming vocoder model") + + parser.add_argument("--voc_pad", type=int, default=14, help="voc pad size.") + + parser.add_argument( + "--voc_block", type=int, default=14, help="voc block size.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + paddle.set_device(args.device) + if args.am_support_stream: + assert (args.am == 'fastspeech2_csmsc') + if args.am_streaming: + assert (args.am_support_stream and args.am == 'fastspeech2_csmsc') + if args.voc_streaming: + assert (args.voc == 'mb_melgan_csmsc' or args.voc == 'hifigan_csmsc') + + logger = logging.getLogger() + fhandler = logging.FileHandler(filename=args.log_file, mode='w') + formatter = logging.Formatter( + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' + ) + fhandler.setFormatter(formatter) + logger.addHandler(fhandler) + logger.setLevel(logging.DEBUG) + + # set basic information + logger.info( + f"AM: {args.am}; Vocoder: {args.voc}; device: {args.device}; am streaming: {args.am_streaming}; voc streaming: {args.voc_streaming}" + ) + logger.info( + f"am pad size: {args.am_pad}; am block size: {args.am_block}; voc pad size: {args.voc_pad}; voc block size: {args.voc_block};" + ) + + # get information about model + frontend, am_infer_info, voc_infer_info = init(args) + logger.info( + "************************ try infer *********************************") + try_infer(args, logger, frontend, am_infer_info, voc_infer_info) + logger.info( + "************************ normal test *******************************") + evaluate(args, logger, frontend, am_infer_info, voc_infer_info) + + +if __name__ == "__main__": + main()