From 82992b3ed6eaffd78fa27fae57235488f2ded168 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Mon, 11 Apr 2022 11:00:04 +0800 Subject: [PATCH 01/31] add test code, test=doc --- .../server/tests/tts/infer/csmsc_test.txt | 100 +++ paddlespeech/server/tests/tts/infer/run.sh | 64 ++ .../server/tests/tts/infer/test_online_tts.py | 650 ++++++++++++++++++ 3 files changed, 814 insertions(+) create mode 100644 paddlespeech/server/tests/tts/infer/csmsc_test.txt create mode 100644 paddlespeech/server/tests/tts/infer/run.sh create mode 100644 paddlespeech/server/tests/tts/infer/test_online_tts.py diff --git a/paddlespeech/server/tests/tts/infer/csmsc_test.txt b/paddlespeech/server/tests/tts/infer/csmsc_test.txt new file mode 100644 index 00000000..d8cf367c --- /dev/null +++ b/paddlespeech/server/tests/tts/infer/csmsc_test.txt @@ -0,0 +1,100 @@ +009901 昨日,这名伤者与医生全部被警方依法刑事拘留。 +009902 钱伟长想到上海来办学校是经过深思熟虑的。 +009903 她见我一进门就骂,吃饭时也骂,骂得我抬不起头。 +009904 李述德在离开之前,只说了一句柱驼杀父亲了。 +009905 这种车票和保险单捆绑出售属于重复性购买。 +009906 戴佩妮的男友西米露接唱情歌,让她非常开心。 +009907 观大势,谋大局,出大策始终是该院的办院方针。 +009908 他们骑着摩托回家,正好为农忙时的父母帮忙。 +009909 但是因为还没到退休年龄,只能掰着指头捱日子。 +009910 这几天雨水不断,人们恨不得待在家里不出门。 +009911 没想到徐赟,张海翔两人就此玩起了人间蒸发。 +009912 藤村此番发言可能是为了凸显野田的领导能力。 +009913 程长庚,生在清王朝嘉庆年间,安徽的潜山小县。 +009914 南海海域综合补给基地码头项目正在论证中。 +009915 也就是说今晚成都市民极有可能再次看到飘雪。 +009916 随着天气转热,各地的游泳场所开始人头攒动。 +009917 更让徐先生纳闷的是,房客的手机也打不通了。 +009918 遇到颠簸时,应听从乘务员的安全指令,回座位坐好。 +009919 他在后面呆惯了,怕自己一插身后的人会不满,不敢排进去。 +009920 傍晚七个小人回来了,白雪公主说,你们就是我命中的七个小矮人吧。 +009921 他本想说,教育局管这个,他们是一路的,这样一管岂不是妓女起嫖客? +009922 一种表示商品所有权的财物证券,也称商品证券,如提货单,交货单。 +009923 会有很丰富的东西留下来,说都说不完。 +009924 这句话像从天而降,吓得四周一片寂静。 +009925 记者所在的是受害人家属所在的右区。 +009926 不管哈大爷去哪,它都一步不离地跟着。 +009927 大家抬头望去,一只老鼠正趴在吊顶上。 +009928 我决定过年就辞职,接手我爸的废品站! +009929 最终,中国男子乒乓球队获得此奖项。 +009930 防汛抗旱两手抓,抗旱相对抓的不够。 +009931 图们江下游地区开发开放的进展如何? +009932 这要求中国必须有一个坚强的政党领导。 +009933 再说,关于利益上的事俺俩都不好开口。 +009934 明代瓦剌,鞑靼入侵明境也是通过此地。 +009935 咪咪舔着孩子,把它身上的毛舔干净。 +009936 是否这次的国标修订被大企业绑架了? +009937 判决后,姚某妻子胡某不服,提起上诉。 +009938 由此可以看出邯钢的经济效益来自何处。 +009939 琳达说,是瑜伽改变了她和马儿的生活。 +009940 楼下的保安告诉记者,这里不租也不卖。 +009941 习近平说,中斯两国人民传统友谊深厚。 +009942 传闻越来越多,后来连老汉儿自己都怕了。 +009943 我怒吼一声冲上去,举起砖头砸了过去。 +009944 我现在还不会,这就回去问问发明我的人。 +009945 显然,洛阳性奴案不具备上述两个前提。 +009946 另外,杰克逊有文唇线,眼线,眉毛的动作。 +009947 昨晚,华西都市报记者电话采访了尹琪。 +009948 涅拉季科未透露这些航空公司的名称。 +009949 从运行轨迹上来说,它也不可能是星星。 +009950 目前看,如果继续加息也存在两难问题。 +009951 曾宝仪在节目录制现场大爆观众糗事。 +009952 但任凭周某怎么叫,男子仍酣睡不醒。 +009953 老大爷说,小子,你挡我财路了,知道不? +009954 没料到,闯下大头佛的阿伟还不知悔改。 +009955 卡扎菲部落式统治已遭遇部落内讧。 +009956 这个孩子的生命一半来源于另一位女士捐赠的冷冻卵子。 +009957 出现这种泥鳅内阁的局面既是野田有意为之,也实属无奈。 +009958 济青高速济南,华山,章丘,邹平,周村,淄博,临淄站。 +009959 赵凌飞的话,反映了沈阳赛区所有奥运志愿者的共同心声。 +009960 因为,我们所发出的力量必会因难度加大而减弱。 +009961 发生事故的楼梯拐角处仍可看到血迹。 +009962 想过进公安,可能身高不够,老汉儿也不让我进去。 +009963 路上关卡很多,为了方便撤离,只好轻装前进。 +009964 原来比尔盖茨就是美国微软公司联合创始人呀。 +009965 之后他们一家三口将与双方父母往峇里岛旅游。 +009966 谢谢总理,也感谢广大网友的参与,我们明年再见。 +009967 事实上是,从来没有一个欺善怕恶的人能作出过稍大一点的成就。 +009968 我会打开邮件,你可以从那里继续。 +009969 美方对近期东海局势表示关切。 +009970 据悉,奥巴马一家人对这座冬季白宫极为满意。 +009971 打扫完你会很有成就感的,试一试,你就信了。 +009972 诺曼站在滑板车上,各就各位,准备出发啦! +009973 塔河的寒夜,气温降到了零下三十多摄氏度。 +009974 其间,连破六点六,六点五,六点四,六点三五等多个重要关口。 +009975 算命其实只是人们的一种自我安慰和自我暗示而已,我们还是要相信科学才好。 +009976 这一切都令人欢欣鼓舞,阿讷西没理由不坚持到最后。 +009977 直至公元前一万一千年,它又再次出现。 +009978 尽量少玩电脑,少看电视,少打游戏。 +009979 从五到七,前后也就是六个月的时间。 +009980 一进咖啡店,他就遇见一张熟悉的脸。 +009981 好在众弟兄看到了把她追了回来。 +009982 有一个人说,哥们儿我们跑过它才能活。 +009983 捅了她以后,模糊记得她没咋动了。 +009984 从小到大,葛启义没有收到过压岁钱。 +009985 舞台下的你会对舞台上的你说什么? +009986 但考生普遍认为,试题的怪多过难。 +009987 我希望每个人都能够尊重我们的隐私。 +009988 漫天的红霞使劲给两人增添气氛。 +009989 晚上加完班开车回家,太累了,迷迷糊糊开着车,走一半的时候,铛一声! +009990 该车将三人撞倒后,在大雾中逃窜。 +009991 这人一哆嗦,方向盘也把不稳了,差点撞上了高速边道护栏。 +009992 那女孩儿委屈的说,我一回头见你已经进去了我不敢进去啊! 
+009993 小明摇摇头说,不是,我只是美女看多了,想换个口味而已。 +009994 接下来,红娘要求记者交费,记者表示不知表姐身份证号码。 +009995 李东蓊表示,自己当时在法庭上发表了一次独特的公诉意见。 +009996 另一男子扑了上来,手里拿着明晃晃的长刀,向他胸口直刺。 +009997 今天,快递员拿着一个快递在办公室喊,秦王是哪个,有他快递? +009998 这场抗议活动究竟是如何发展演变的,又究竟是谁伤害了谁? +009999 因华国锋肖鸡,墓地设计根据其属相设计。 +010000 在狱中,张明宝悔恨交加,写了一份忏悔书。 diff --git a/paddlespeech/server/tests/tts/infer/run.sh b/paddlespeech/server/tests/tts/infer/run.sh new file mode 100644 index 00000000..fdceec41 --- /dev/null +++ b/paddlespeech/server/tests/tts/infer/run.sh @@ -0,0 +1,64 @@ +model_path=/home/users/liangyunming/.paddlespeech/models/ +#am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_nosil_baker_ckpt_0.4/ ## fastspeech2 +am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/ ## fastspeech2_cnn +voc_model_dir=$model_path/hifigan_csmsc-zh/hifigan_csmsc_ckpt_0.1.1/ ## hifigan +#voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/ ## mb_melgan + +if [[ $am_model_dir == *"fastspeech2_cnndecoder"* ]]; then + am_support_stream=True +else + am_support_stream=False +fi + +# get am file +for file in $(ls $am_model_dir) +do + if [[ $file == *"yaml"* ]]; then + am_config_file=$file + elif [[ $file == *"pdz"* ]]; then + am_ckpt_file=$file + elif [[ $file == *"stat"* ]]; then + am_stat_file=$file + elif [[ $file == *"phone"* ]]; then + phones_dict_file=$file + fi + +done + +# get voc file +for file in $(ls $voc_model_dir) +do + if [[ $file == *"yaml"* ]]; then + voc_config_file=$file + elif [[ $file == *"pdz"* ]]; then + voc_ckpt_file=$file + elif [[ $file == *"stat"* ]]; then + voc_stat_file=$file + fi + +done + + +#run +python test_online_tts.py --am fastspeech2_csmsc \ + --am_support_stream $am_support_stream \ + --am_config $am_model_dir/$am_config_file \ + --am_ckpt $am_model_dir/$am_ckpt_file \ + --am_stat $am_model_dir/$am_stat_file \ + --phones_dict $am_model_dir/$phones_dict_file \ + --voc hifigan_csmsc \ + --voc_config $voc_model_dir/$voc_config_file \ + --voc_ckpt $voc_model_dir/$voc_ckpt_file \ + --voc_stat $voc_model_dir/$voc_stat_file \ + --lang zh \ + --device cpu \ + --text ./csmsc_test.txt \ + --output_dir ./output \ + --log_file ./result.log \ + --am_streaming False \ + --am_pad 12 \ + --am_block 42 \ + --voc_streaming True \ + --voc_pad 14 \ + --voc_block 14 \ + diff --git a/paddlespeech/server/tests/tts/infer/test_online_tts.py b/paddlespeech/server/tests/tts/infer/test_online_tts.py new file mode 100644 index 00000000..17ac0ea7 --- /dev/null +++ b/paddlespeech/server/tests/tts/infer/test_online_tts.py @@ -0,0 +1,650 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import argparse +import logging +import math +import threading +import time +from pathlib import Path + +import numpy as np +import paddle +import soundfile as sf +import yaml +from yacs.config import CfgNode + +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.t2s.exps.syn_utils import get_am_inference +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.exps.syn_utils import get_voc_inference +from paddlespeech.t2s.exps.syn_utils import model_alias +from paddlespeech.t2s.utils import str2bool + +mel_streaming = None +wav_streaming = None +stream_first_time = 0.0 +voc_stream_st = 0.0 +sample_rate = 0 + + +def denorm(data, mean, std): + return data * std + mean + + +def get_chunks(data, block_size, pad_size, step): + if step == "am": + data_len = data.shape[1] + elif step == "voc": + data_len = data.shape[0] + else: + print("Please set correct type to get chunks, am or voc") + + chunks = [] + n = math.ceil(data_len / block_size) + for i in range(n): + start = max(0, i * block_size - pad_size) + end = min((i + 1) * block_size + pad_size, data_len) + if step == "am": + chunks.append(data[:, start:end, :]) + elif step == "voc": + chunks.append(data[start:end, :]) + else: + print("Please set correct type to get chunks, am or voc") + return chunks + + +def get_stream_am_inference(args, am_config): + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + odim = am_config.n_mels + + am_class = dynamic_import(am_name, model_alias) + am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) + am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) + am.eval() + am_mu, am_std = np.load(args.am_stat) + am_mu = paddle.to_tensor(am_mu) + am_std = paddle.to_tensor(am_std) + + return am, am_mu, am_std + + +def init(args): + global sample_rate + # get config + with open(args.am_config) as f: + am_config = CfgNode(yaml.safe_load(f)) + with open(args.voc_config) as f: + voc_config = CfgNode(yaml.safe_load(f)) + + sample_rate = am_config.fs + + # frontend + frontend = get_frontend(args) + + # acoustic model + if args.am_support_stream: + am, am_mu, am_std = get_stream_am_inference(args, am_config) + am_infer_info = [am, am_mu, am_std, am_config] + else: + am_inference, am_name, am_dataset = get_am_inference(args, am_config) + am_infer_info = [am_inference, am_name, am_dataset, am_config] + + # vocoder + voc_inference = get_voc_inference(args, voc_config) + voc_infer_info = [voc_inference, voc_config] + + return frontend, am_infer_info, voc_infer_info + + +def get_phone(args, frontend, sentence, merge_sentences, get_tone_ids): + am_name = args.am[:args.am.rindex('_')] + tone_ids = None + if am_name == 'speedyspeech': + get_tone_ids = True + + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + elif args.lang == 'en': + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + + return phone_ids, tone_ids + + +@paddle.no_grad() +# 生成完整的mel +def gen_mel(args, am_infer_info, part_phone_ids, 
part_tone_ids): + # 如果是支持流式的AM模型 + if args.am_support_stream: + am, am_mu, am_std, am_config = am_infer_info + orig_hs, h_masks = am.encoder_infer(part_phone_ids) + if args.am_streaming: + am_pad = args.am_pad + am_block = args.am_block + hss = get_chunks(orig_hs, am_block, am_pad, "am") + chunk_num = len(hss) + mel_list = [] + for i, hs in enumerate(hss): + before_outs, _ = am.decoder(hs) + after_outs = before_outs + am.postnet( + before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + normalized_mel = after_outs[0] + sub_mel = denorm(normalized_mel, am_mu, am_std) + # clip output part of pad + if i == 0: + sub_mel = sub_mel[:-am_pad] + elif i == chunk_num - 1: + # 最后一块的右侧一定没有 pad 够 + sub_mel = sub_mel[am_pad:] + else: + # 倒数几块的右侧也可能没有 pad 够 + sub_mel = sub_mel[am_pad:(am_block + am_pad) - + sub_mel.shape[0]] + mel_list.append(sub_mel) + mel = paddle.concat(mel_list, axis=0) + + else: + orig_hs, h_masks = am.encoder_infer(part_phone_ids) + before_outs, _ = am.decoder(orig_hs) + after_outs = before_outs + am.postnet( + before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + normalized_mel = after_outs[0] + mel = denorm(normalized_mel, am_mu, am_std) + + else: + am_inference, am_name, am_dataset, am_config = am_infer_info + # acoustic model + if am_name == 'fastspeech2': + # multi speaker + if am_dataset in {"aishell3", "vctk"}: + spk_id = paddle.to_tensor(args.spk_id) + mel = am_inference(part_phone_ids, spk_id) + else: + mel = am_inference(part_phone_ids) + elif am_name == 'speedyspeech': + part_tone_ids = tone_ids[i] + if am_dataset in {"aishell3", "vctk"}: + spk_id = paddle.to_tensor(args.spk_id) + mel = am_inference(part_phone_ids, part_tone_ids, spk_id) + else: + mel = am_inference(part_phone_ids, part_tone_ids) + elif am_name == 'tacotron2': + mel = am_inference(part_phone_ids) + + return mel + + +@paddle.no_grad() +def stream_voc_infer(args, voc_infer_info, mel_len): + global mel_streaming + global stream_first_time + global wav_streaming + voc_inference, voc_config = voc_infer_info + block = args.voc_block + pad = args.voc_pad + upsample = voc_config.n_shift + wav_list = [] + flag = 1 + + valid_start = 0 + valid_end = min(valid_start + block, mel_len) + actual_start = 0 + actual_end = min(valid_end + pad, mel_len) + mel_chunk = mel_streaming[actual_start:actual_end, :] + + while valid_end <= mel_len: + sub_wav = voc_inference(mel_chunk) + if flag == 1: + stream_first_time = time.time() + flag = 0 + + # get valid wav + start = valid_start - actual_start + if valid_end == mel_len: + sub_wav = sub_wav[start * upsample:] + wav_list.append(sub_wav) + break + else: + end = start + block + sub_wav = sub_wav[start * upsample:end * upsample] + wav_list.append(sub_wav) + + # generate new mel chunk + valid_start = valid_end + valid_end = min(valid_start + block, mel_len) + if valid_start - pad < 0: + actual_start = 0 + else: + actual_start = valid_start - pad + actual_end = min(valid_end + pad, mel_len) + mel_chunk = mel_streaming[actual_start:actual_end, :] + + wav = paddle.concat(wav_list, axis=0) + wav_streaming = wav + + +@paddle.no_grad() +# 非流式AM / 流式AM + 非流式Voc +def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, + part_tone_ids): + mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids) + am_infer_time = time.time() + voc_inference, voc_config = voc_infer_info + wav = voc_inference(mel) + first_response_time = time.time() + final_response_time = first_response_time + voc_infer_time = first_response_time + + return am_infer_time, voc_infer_time, 
first_response_time, final_response_time, wav + + +@paddle.no_grad() +# 非流式AM + 流式Voc +def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, + part_tone_ids): + global mel_streaming + global stream_first_time + global wav_streaming + + mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids) + am_infer_time = time.time() + + # voc streaming + mel_streaming = mel + mel_len = mel.shape[0] + stream_voc_infer(args, voc_infer_info, mel_len) + first_response_time = stream_first_time + wav = wav_streaming + final_response_time = time.time() + voc_infer_time = final_response_time + + return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav + + +@paddle.no_grad() +# 流式AM + 流式 Voc +def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, + part_tone_ids): + global mel_streaming + global stream_first_time + global wav_streaming + global voc_stream_st + mel_streaming = None + flag = 1 #用来表示开启流式voc的线程 + + am, am_mu, am_std, am_config = am_infer_info + orig_hs, h_masks = am.encoder_infer(part_phone_ids) + mel_len = orig_hs.shape[1] + am_block = args.am_block + am_pad = args.am_pad + hss = get_chunks(orig_hs, am_block, am_pad, "am") + chunk_num = len(hss) + + for i, hs in enumerate(hss): + before_outs, _ = am.decoder(hs) + after_outs = before_outs + am.postnet( + before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + normalized_mel = after_outs[0] + sub_mel = denorm(normalized_mel, am_mu, am_std) + # clip output part of pad + if i == 0: + sub_mel = sub_mel[:-am_pad] + mel_streaming = sub_mel + elif i == chunk_num - 1: + # 最后一块的右侧一定没有 pad 够 + sub_mel = sub_mel[am_pad:] + mel_streaming = paddle.concat([mel_streaming, sub_mel]) + am_infer_time = time.time() + else: + # 倒数几块的右侧也可能没有 pad 够 + sub_mel = sub_mel[am_pad:(am_block + am_pad) - sub_mel.shape[0]] + mel_streaming = paddle.concat([mel_streaming, sub_mel]) + + if flag and mel_streaming.shape[0] > args.voc_block + args.voc_pad: + t = threading.Thread( + target=stream_voc_infer, args=(args, voc_infer_info, mel_len, )) + t.start() + voc_stream_st = time.time() + flag = 0 + + t.join() + final_response_time = time.time() + voc_infer_time = final_response_time + first_response_time = stream_first_time + wav = wav_streaming + + return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav + + +def try_infer(args, logger, frontend, am_infer_info, voc_infer_info): + global sample_rate + logger.info( + "Before the formal test, we test a few texts to make the inference speed more stable." + ) + if args.lang == 'zh': + sentence = "您好,欢迎使用语音合成服务。" + if args.lang == 'en': + sentence = "Hello and welcome to the speech synthesis service." 
+ + if args.voc_streaming: + if args.am_streaming: + infer_func = stream_am_stream_voc + else: + infer_func = nostream_am_stream_voc + else: + infer_func = am_nostream_voc + + merge_sentences = True + get_tone_ids = False + for i in range(3): # 推理3次 + st = time.time() + phone_ids, tone_ids = get_phone(args, frontend, sentence, + merge_sentences, get_tone_ids) + part_phone_ids = phone_ids[0] + if tone_ids: + part_tone_ids = tone_ids[0] + else: + part_tone_ids = None + + am_infer_time, voc_infer_time, first_response_time, final_response_time, wav = infer_func( + args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids) + wav = wav.numpy() + duration = wav.size / sample_rate + logger.info( + f"sentence: {sentence}; duration: {duration} s; first response time: {first_response_time - st} s; final response time: {final_response_time - st} s" + ) + + +def evaluate(args, logger, frontend, am_infer_info, voc_infer_info): + global sample_rate + sentences = get_sentences(args) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + get_tone_ids = False + merge_sentences = True + + # choose infer function + if args.voc_streaming: + if args.am_streaming: + infer_func = stream_am_stream_voc + else: + infer_func = nostream_am_stream_voc + else: + infer_func = am_nostream_voc + + final_up_duration = 0.0 + sentence_count = 0 + front_time_list = [] + am_time_list = [] + voc_time_list = [] + first_response_list = [] + final_response_list = [] + sentence_length_list = [] + duration_list = [] + + for utt_id, sentence in sentences: + # front + front_st = time.time() + phone_ids, tone_ids = get_phone(args, frontend, sentence, + merge_sentences, get_tone_ids) + part_phone_ids = phone_ids[0] + if tone_ids: + part_tone_ids = tone_ids[0] + else: + part_tone_ids = None + front_et = time.time() + front_time = front_et - front_st + + am_st = time.time() + am_infer_time, voc_infer_time, first_response_time, final_response_time, wav = infer_func( + args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids) + am_time = am_infer_time - am_st + if args.voc_streaming and args.am_streaming: + voc_time = voc_infer_time - voc_stream_st + else: + voc_time = voc_infer_time - am_infer_time + + first_response = first_response_time - front_st + final_response = final_response_time - front_st + + wav = wav.numpy() + duration = wav.size / sample_rate + sf.write( + str(output_dir / (utt_id + ".wav")), wav, samplerate=sample_rate) + print(f"{utt_id} done!") + + sentence_count += 1 + front_time_list.append(front_time) + am_time_list.append(am_time) + voc_time_list.append(voc_time) + first_response_list.append(first_response) + final_response_list.append(final_response) + sentence_length_list.append(len(sentence)) + duration_list.append(duration) + + logger.info( + f"uttid: {utt_id}; sentence: '{sentence}'; front time: {front_time} s; am time: {am_time} s; voc time: {voc_time} s; \ + first response time: {first_response} s; final response time: {final_response} s; audio duration: {duration} s;" + ) + + if final_response > duration: + final_up_duration += 1 + + all_time_sum = sum(final_response_list) + front_rate = sum(front_time_list) / all_time_sum + am_rate = sum(am_time_list) / all_time_sum + voc_rate = sum(voc_time_list) / all_time_sum + rtf = all_time_sum / sum(duration_list) + + logger.info( + f"The length of test text information, test num: {sentence_count}; text num: {sum(sentence_length_list)}; min: {min(sentence_length_list)}; max: {max(sentence_length_list)}; avg: 
{sum(sentence_length_list)/len(sentence_length_list)}" + ) + logger.info( + f"duration information, min: {min(duration_list)}; max: {max(duration_list)}; avg: {sum(duration_list) / len(duration_list)}; sum: {sum(duration_list)}" + ) + logger.info( + f"Front time information: min: {min(front_time_list)} s; max: {max(front_time_list)} s; avg: {sum(front_time_list)/len(front_time_list)} s; ratio: {front_rate * 100}%" + ) + logger.info( + f"AM time information: min: {min(am_time_list)} s; max: {max(am_time_list)} s; avg: {sum(am_time_list)/len(am_time_list)} s; ratio: {am_rate * 100}%" + ) + logger.info( + f"Vocoder time information: min: {min(voc_time_list)} s, max: {max(voc_time_list)} s; avg: {sum(voc_time_list)/len(voc_time_list)} s; ratio: {voc_rate * 100}%" + ) + logger.info( + f"first response time information: min: {min(first_response_list)} s; max: {max(first_response_list)} s; avg: {sum(first_response_list)/len(first_response_list)} s" + ) + logger.info( + f"final response time information: min: {min(final_response_list)} s; max: {max(final_response_list)} s; avg: {sum(final_response_list)/len(final_response_list)} s" + ) + logger.info(f"RTF is: {rtf}") + logger.info( + f"The number of final_response is greater than duration is {final_up_duration}, ratio: {final_up_duration / sentence_count}%" + ) + + +def parse_args(): + # parse args and config and redirect to train_sp + parser = argparse.ArgumentParser( + description="Synthesize with acoustic model & vocoder") + # acoustic model + parser.add_argument( + '--am', + type=str, + default='fastspeech2_csmsc', + choices=[ + 'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc', + 'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk', + 'tacotron2_csmsc', 'tacotron2_ljspeech' + ], + help='Choose acoustic model type of tts task.') + parser.add_argument( + '--am_support_stream', + type=str2bool, + default=False, + help='if am model is fastspeech2_csmsc, specify whether it supports streaming' + ) + parser.add_argument( + '--am_config', + type=str, + default=None, + help='Config of acoustic model. Use deault config when it is None.') + parser.add_argument( + '--am_ckpt', + type=str, + default=None, + help='Checkpoint file of acoustic model.') + parser.add_argument( + "--am_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training acoustic model." + ) + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--tones_dict", type=str, default=None, help="tone vocabulary file.") + parser.add_argument( + "--speaker_dict", type=str, default=None, help="speaker id map file.") + parser.add_argument( + '--spk_id', + type=int, + default=0, + help='spk id for multi speaker acoustic model') + # vocoder + parser.add_argument( + '--voc', + type=str, + default='mb_melgan_csmsc', + choices=[ + 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', + 'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc', + 'wavernn_csmsc' + ], + help='Choose vocoder type of tts task.') + parser.add_argument( + '--voc_config', + type=str, + default=None, + help='Config of voc. Use deault config when it is None.') + parser.add_argument( + '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') + parser.add_argument( + "--voc_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training voc." 
+ ) + # other + parser.add_argument( + '--lang', + type=str, + default='zh', + choices=['zh', 'en'], + help='Choose model language. zh or en') + + parser.add_argument( + "--device", type=str, default='cpu', help="set cpu or gpu:id") + + parser.add_argument( + "--text", + type=str, + default="./csmsc_test.txt", + help="text to synthesize, a 'utt_id sentence' pair per line.") + parser.add_argument("--output_dir", type=str, help="output dir.") + parser.add_argument( + "--log_file", type=str, default="result.log", help="log file.") + + parser.add_argument( + "--am_streaming", + type=str2bool, + default=False, + help="whether use streaming acoustic model") + + parser.add_argument("--am_pad", type=int, default=12, help="am pad size.") + + parser.add_argument( + "--am_block", type=int, default=42, help="am block size.") + + parser.add_argument( + "--voc_streaming", + type=str2bool, + default=False, + help="whether use streaming vocoder model") + + parser.add_argument("--voc_pad", type=int, default=14, help="voc pad size.") + + parser.add_argument( + "--voc_block", type=int, default=14, help="voc block size.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + paddle.set_device(args.device) + if args.am_support_stream: + assert (args.am == 'fastspeech2_csmsc') + if args.am_streaming: + assert (args.am_support_stream and args.am == 'fastspeech2_csmsc') + if args.voc_streaming: + assert (args.voc == 'mb_melgan_csmsc' or args.voc == 'hifigan_csmsc') + + logger = logging.getLogger() + fhandler = logging.FileHandler(filename=args.log_file, mode='w') + formatter = logging.Formatter( + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' + ) + fhandler.setFormatter(formatter) + logger.addHandler(fhandler) + logger.setLevel(logging.DEBUG) + + # set basic information + logger.info( + f"AM: {args.am}; Vocoder: {args.voc}; device: {args.device}; am streaming: {args.am_streaming}; voc streaming: {args.voc_streaming}" + ) + logger.info( + f"am pad size: {args.am_pad}; am block size: {args.am_block}; voc pad size: {args.voc_pad}; voc block size: {args.voc_block};" + ) + + # get information about model + frontend, am_infer_info, voc_infer_info = init(args) + logger.info( + "************************ try infer *********************************") + try_infer(args, logger, frontend, am_infer_info, voc_infer_info) + logger.info( + "************************ normal test *******************************") + evaluate(args, logger, frontend, am_infer_info, voc_infer_info) + + +if __name__ == "__main__": + main() From 4b111146dc959daac319879ba8d89fb9a3f24b75 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Mon, 11 Apr 2022 15:31:03 +0800 Subject: [PATCH 02/31] code format, test=doc --- .../server/tests/tts/infer/csmsc_test.txt | 100 ------------------ paddlespeech/server/tests/tts/infer/run.sh | 28 ++--- .../server/tests/tts/infer/test_online_tts.py | 71 +++---------- 3 files changed, 26 insertions(+), 173 deletions(-) delete mode 100644 paddlespeech/server/tests/tts/infer/csmsc_test.txt diff --git a/paddlespeech/server/tests/tts/infer/csmsc_test.txt b/paddlespeech/server/tests/tts/infer/csmsc_test.txt deleted file mode 100644 index d8cf367c..00000000 --- a/paddlespeech/server/tests/tts/infer/csmsc_test.txt +++ /dev/null @@ -1,100 +0,0 @@ -009901 昨日,这名伤者与医生全部被警方依法刑事拘留。 -009902 钱伟长想到上海来办学校是经过深思熟虑的。 -009903 她见我一进门就骂,吃饭时也骂,骂得我抬不起头。 -009904 李述德在离开之前,只说了一句柱驼杀父亲了。 -009905 这种车票和保险单捆绑出售属于重复性购买。 -009906 戴佩妮的男友西米露接唱情歌,让她非常开心。 -009907 观大势,谋大局,出大策始终是该院的办院方针。 -009908 
他们骑着摩托回家,正好为农忙时的父母帮忙。 -009909 但是因为还没到退休年龄,只能掰着指头捱日子。 -009910 这几天雨水不断,人们恨不得待在家里不出门。 -009911 没想到徐赟,张海翔两人就此玩起了人间蒸发。 -009912 藤村此番发言可能是为了凸显野田的领导能力。 -009913 程长庚,生在清王朝嘉庆年间,安徽的潜山小县。 -009914 南海海域综合补给基地码头项目正在论证中。 -009915 也就是说今晚成都市民极有可能再次看到飘雪。 -009916 随着天气转热,各地的游泳场所开始人头攒动。 -009917 更让徐先生纳闷的是,房客的手机也打不通了。 -009918 遇到颠簸时,应听从乘务员的安全指令,回座位坐好。 -009919 他在后面呆惯了,怕自己一插身后的人会不满,不敢排进去。 -009920 傍晚七个小人回来了,白雪公主说,你们就是我命中的七个小矮人吧。 -009921 他本想说,教育局管这个,他们是一路的,这样一管岂不是妓女起嫖客? -009922 一种表示商品所有权的财物证券,也称商品证券,如提货单,交货单。 -009923 会有很丰富的东西留下来,说都说不完。 -009924 这句话像从天而降,吓得四周一片寂静。 -009925 记者所在的是受害人家属所在的右区。 -009926 不管哈大爷去哪,它都一步不离地跟着。 -009927 大家抬头望去,一只老鼠正趴在吊顶上。 -009928 我决定过年就辞职,接手我爸的废品站! -009929 最终,中国男子乒乓球队获得此奖项。 -009930 防汛抗旱两手抓,抗旱相对抓的不够。 -009931 图们江下游地区开发开放的进展如何? -009932 这要求中国必须有一个坚强的政党领导。 -009933 再说,关于利益上的事俺俩都不好开口。 -009934 明代瓦剌,鞑靼入侵明境也是通过此地。 -009935 咪咪舔着孩子,把它身上的毛舔干净。 -009936 是否这次的国标修订被大企业绑架了? -009937 判决后,姚某妻子胡某不服,提起上诉。 -009938 由此可以看出邯钢的经济效益来自何处。 -009939 琳达说,是瑜伽改变了她和马儿的生活。 -009940 楼下的保安告诉记者,这里不租也不卖。 -009941 习近平说,中斯两国人民传统友谊深厚。 -009942 传闻越来越多,后来连老汉儿自己都怕了。 -009943 我怒吼一声冲上去,举起砖头砸了过去。 -009944 我现在还不会,这就回去问问发明我的人。 -009945 显然,洛阳性奴案不具备上述两个前提。 -009946 另外,杰克逊有文唇线,眼线,眉毛的动作。 -009947 昨晚,华西都市报记者电话采访了尹琪。 -009948 涅拉季科未透露这些航空公司的名称。 -009949 从运行轨迹上来说,它也不可能是星星。 -009950 目前看,如果继续加息也存在两难问题。 -009951 曾宝仪在节目录制现场大爆观众糗事。 -009952 但任凭周某怎么叫,男子仍酣睡不醒。 -009953 老大爷说,小子,你挡我财路了,知道不? -009954 没料到,闯下大头佛的阿伟还不知悔改。 -009955 卡扎菲部落式统治已遭遇部落内讧。 -009956 这个孩子的生命一半来源于另一位女士捐赠的冷冻卵子。 -009957 出现这种泥鳅内阁的局面既是野田有意为之,也实属无奈。 -009958 济青高速济南,华山,章丘,邹平,周村,淄博,临淄站。 -009959 赵凌飞的话,反映了沈阳赛区所有奥运志愿者的共同心声。 -009960 因为,我们所发出的力量必会因难度加大而减弱。 -009961 发生事故的楼梯拐角处仍可看到血迹。 -009962 想过进公安,可能身高不够,老汉儿也不让我进去。 -009963 路上关卡很多,为了方便撤离,只好轻装前进。 -009964 原来比尔盖茨就是美国微软公司联合创始人呀。 -009965 之后他们一家三口将与双方父母往峇里岛旅游。 -009966 谢谢总理,也感谢广大网友的参与,我们明年再见。 -009967 事实上是,从来没有一个欺善怕恶的人能作出过稍大一点的成就。 -009968 我会打开邮件,你可以从那里继续。 -009969 美方对近期东海局势表示关切。 -009970 据悉,奥巴马一家人对这座冬季白宫极为满意。 -009971 打扫完你会很有成就感的,试一试,你就信了。 -009972 诺曼站在滑板车上,各就各位,准备出发啦! -009973 塔河的寒夜,气温降到了零下三十多摄氏度。 -009974 其间,连破六点六,六点五,六点四,六点三五等多个重要关口。 -009975 算命其实只是人们的一种自我安慰和自我暗示而已,我们还是要相信科学才好。 -009976 这一切都令人欢欣鼓舞,阿讷西没理由不坚持到最后。 -009977 直至公元前一万一千年,它又再次出现。 -009978 尽量少玩电脑,少看电视,少打游戏。 -009979 从五到七,前后也就是六个月的时间。 -009980 一进咖啡店,他就遇见一张熟悉的脸。 -009981 好在众弟兄看到了把她追了回来。 -009982 有一个人说,哥们儿我们跑过它才能活。 -009983 捅了她以后,模糊记得她没咋动了。 -009984 从小到大,葛启义没有收到过压岁钱。 -009985 舞台下的你会对舞台上的你说什么? -009986 但考生普遍认为,试题的怪多过难。 -009987 我希望每个人都能够尊重我们的隐私。 -009988 漫天的红霞使劲给两人增添气氛。 -009989 晚上加完班开车回家,太累了,迷迷糊糊开着车,走一半的时候,铛一声! -009990 该车将三人撞倒后,在大雾中逃窜。 -009991 这人一哆嗦,方向盘也把不稳了,差点撞上了高速边道护栏。 -009992 那女孩儿委屈的说,我一回头见你已经进去了我不敢进去啊! -009993 小明摇摇头说,不是,我只是美女看多了,想换个口味而已。 -009994 接下来,红娘要求记者交费,记者表示不知表姐身份证号码。 -009995 李东蓊表示,自己当时在法庭上发表了一次独特的公诉意见。 -009996 另一男子扑了上来,手里拿着明晃晃的长刀,向他胸口直刺。 -009997 今天,快递员拿着一个快递在办公室喊,秦王是哪个,有他快递? -009998 这场抗议活动究竟是如何发展演变的,又究竟是谁伤害了谁? 
-009999 因华国锋肖鸡,墓地设计根据其属相设计。 -010000 在狱中,张明宝悔恨交加,写了一份忏悔书。 diff --git a/paddlespeech/server/tests/tts/infer/run.sh b/paddlespeech/server/tests/tts/infer/run.sh index fdceec41..631daddd 100644 --- a/paddlespeech/server/tests/tts/infer/run.sh +++ b/paddlespeech/server/tests/tts/infer/run.sh @@ -1,14 +1,7 @@ -model_path=/home/users/liangyunming/.paddlespeech/models/ -#am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_nosil_baker_ckpt_0.4/ ## fastspeech2 -am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/ ## fastspeech2_cnn -voc_model_dir=$model_path/hifigan_csmsc-zh/hifigan_csmsc_ckpt_0.1.1/ ## hifigan -#voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/ ## mb_melgan - -if [[ $am_model_dir == *"fastspeech2_cnndecoder"* ]]; then - am_support_stream=True -else - am_support_stream=False -fi +model_path=~/.paddlespeech/models/ +am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/ ## fastspeech2_c +voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/ ## mb_melgan +testdata=../../../../t2s/exps/csmsc_test.txt # get am file for file in $(ls $am_model_dir) @@ -39,23 +32,24 @@ do done -#run -python test_online_tts.py --am fastspeech2_csmsc \ - --am_support_stream $am_support_stream \ +# run test +# am can choose fastspeech2_csmsc or fastspeech2-C_csmsc, where fastspeech2-C_csmsc supports streaming inference. +# voc can choose hifigan_csmsc and mb_melgan_csmsc, They can both support streaming inference. +python test_online_tts.py --am fastspeech2-C_csmsc \ --am_config $am_model_dir/$am_config_file \ --am_ckpt $am_model_dir/$am_ckpt_file \ --am_stat $am_model_dir/$am_stat_file \ --phones_dict $am_model_dir/$phones_dict_file \ - --voc hifigan_csmsc \ + --voc mb_melgan_csmsc \ --voc_config $voc_model_dir/$voc_config_file \ --voc_ckpt $voc_model_dir/$voc_ckpt_file \ --voc_stat $voc_model_dir/$voc_stat_file \ --lang zh \ --device cpu \ - --text ./csmsc_test.txt \ + --text $testdata \ --output_dir ./output \ --log_file ./result.log \ - --am_streaming False \ + --am_streaming True \ --am_pad 12 \ --am_block 42 \ --voc_streaming True \ diff --git a/paddlespeech/server/tests/tts/infer/test_online_tts.py b/paddlespeech/server/tests/tts/infer/test_online_tts.py index 17ac0ea7..8ccf724b 100644 --- a/paddlespeech/server/tests/tts/infer/test_online_tts.py +++ b/paddlespeech/server/tests/tts/infer/test_online_tts.py @@ -71,8 +71,7 @@ def get_stream_am_inference(args, am_config): vocab_size = len(phn_id) print("vocab_size:", vocab_size) - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] + am_name = "fastspeech2" odim = am_config.n_mels am_class = dynamic_import(am_name, model_alias) @@ -100,7 +99,7 @@ def init(args): frontend = get_frontend(args) # acoustic model - if args.am_support_stream: + if args.am == 'fastspeech2-C_csmsc': am, am_mu, am_std = get_stream_am_inference(args, am_config) am_infer_info = [am, am_mu, am_std, am_config] else: @@ -117,8 +116,6 @@ def init(args): def get_phone(args, frontend, sentence, merge_sentences, get_tone_ids): am_name = args.am[:args.am.rindex('_')] tone_ids = None - if am_name == 'speedyspeech': - get_tone_ids = True if args.lang == 'zh': input_ids = frontend.get_input_ids( @@ -142,7 +139,7 @@ def get_phone(args, frontend, sentence, merge_sentences, get_tone_ids): # 生成完整的mel def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids): # 如果是支持流式的AM模型 - if args.am_support_stream: + if args.am == 'fastspeech2-C_csmsc': am, 
am_mu, am_std, am_config = am_infer_info orig_hs, h_masks = am.encoder_infer(part_phone_ids) if args.am_streaming: @@ -180,23 +177,7 @@ def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids): else: am_inference, am_name, am_dataset, am_config = am_infer_info - # acoustic model - if am_name == 'fastspeech2': - # multi speaker - if am_dataset in {"aishell3", "vctk"}: - spk_id = paddle.to_tensor(args.spk_id) - mel = am_inference(part_phone_ids, spk_id) - else: - mel = am_inference(part_phone_ids) - elif am_name == 'speedyspeech': - part_tone_ids = tone_ids[i] - if am_dataset in {"aishell3", "vctk"}: - spk_id = paddle.to_tensor(args.spk_id) - mel = am_inference(part_phone_ids, part_tone_ids, spk_id) - else: - mel = am_inference(part_phone_ids, part_tone_ids) - elif am_name == 'tacotron2': - mel = am_inference(part_phone_ids) + mel = am_inference(part_phone_ids) return mel @@ -297,7 +278,8 @@ def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, global wav_streaming global voc_stream_st mel_streaming = None - flag = 1 #用来表示开启流式voc的线程 + #用来表示开启流式voc的线程 + flag = 1 am, am_mu, am_std, am_config = am_infer_info orig_hs, h_masks = am.encoder_infer(part_phone_ids) @@ -343,7 +325,7 @@ def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav -def try_infer(args, logger, frontend, am_infer_info, voc_infer_info): +def warm_up(args, logger, frontend, am_infer_info, voc_infer_info): global sample_rate logger.info( "Before the formal test, we test a few texts to make the inference speed more stable." @@ -363,7 +345,7 @@ def try_infer(args, logger, frontend, am_infer_info, voc_infer_info): merge_sentences = True get_tone_ids = False - for i in range(3): # 推理3次 + for i in range(5): # 推理5次 st = time.time() phone_ids, tone_ids = get_phone(args, frontend, sentence, merge_sentences, get_tone_ids) @@ -500,18 +482,10 @@ def parse_args(): '--am', type=str, default='fastspeech2_csmsc', - choices=[ - 'speedyspeech_csmsc', 'speedyspeech_aishell3', 'fastspeech2_csmsc', - 'fastspeech2_ljspeech', 'fastspeech2_aishell3', 'fastspeech2_vctk', - 'tacotron2_csmsc', 'tacotron2_ljspeech' - ], - help='Choose acoustic model type of tts task.') - parser.add_argument( - '--am_support_stream', - type=str2bool, - default=False, - help='if am model is fastspeech2_csmsc, specify whether it supports streaming' + choices=['fastspeech2_csmsc', 'fastspeech2-C_csmsc'], + help='Choose acoustic model type of tts task. 
where fastspeech2-C_csmsc supports streaming inference' ) + parser.add_argument( '--am_config', type=str, @@ -532,23 +506,12 @@ def parse_args(): "--phones_dict", type=str, default=None, help="phone vocabulary file.") parser.add_argument( "--tones_dict", type=str, default=None, help="tone vocabulary file.") - parser.add_argument( - "--speaker_dict", type=str, default=None, help="speaker id map file.") - parser.add_argument( - '--spk_id', - type=int, - default=0, - help='spk id for multi speaker acoustic model') # vocoder parser.add_argument( '--voc', type=str, default='mb_melgan_csmsc', - choices=[ - 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', - 'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc', - 'wavernn_csmsc' - ], + choices=['mb_melgan_csmsc', 'hifigan_csmsc'], help='Choose vocoder type of tts task.') parser.add_argument( '--voc_config', @@ -612,12 +575,8 @@ def parse_args(): def main(): args = parse_args() paddle.set_device(args.device) - if args.am_support_stream: - assert (args.am == 'fastspeech2_csmsc') if args.am_streaming: - assert (args.am_support_stream and args.am == 'fastspeech2_csmsc') - if args.voc_streaming: - assert (args.voc == 'mb_melgan_csmsc' or args.voc == 'hifigan_csmsc') + assert (args.am == 'fastspeech2-C_csmsc') logger = logging.getLogger() fhandler = logging.FileHandler(filename=args.log_file, mode='w') @@ -639,8 +598,8 @@ def main(): # get information about model frontend, am_infer_info, voc_infer_info = init(args) logger.info( - "************************ try infer *********************************") - try_infer(args, logger, frontend, am_infer_info, voc_infer_info) + "************************ warm up *********************************") + warm_up(args, logger, frontend, am_infer_info, voc_infer_info) logger.info( "************************ normal test *******************************") evaluate(args, logger, frontend, am_infer_info, voc_infer_info) From 9d0224460bec81139fd7d69732dce0f7c7ec36fa Mon Sep 17 00:00:00 2001 From: lym0302 Date: Mon, 11 Apr 2022 15:54:44 +0800 Subject: [PATCH 03/31] code format, test=doc --- paddlespeech/server/tests/tts/infer/run.sh | 12 ++-- .../server/tests/tts/infer/test_online_tts.py | 67 ++++++++++--------- 2 files changed, 42 insertions(+), 37 deletions(-) diff --git a/paddlespeech/server/tests/tts/infer/run.sh b/paddlespeech/server/tests/tts/infer/run.sh index 631daddd..3733c3fb 100644 --- a/paddlespeech/server/tests/tts/infer/run.sh +++ b/paddlespeech/server/tests/tts/infer/run.sh @@ -1,6 +1,6 @@ model_path=~/.paddlespeech/models/ -am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/ ## fastspeech2_c -voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/ ## mb_melgan +am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/ +voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/ testdata=../../../../t2s/exps/csmsc_test.txt # get am file @@ -33,9 +33,13 @@ done # run test -# am can choose fastspeech2_csmsc or fastspeech2-C_csmsc, where fastspeech2-C_csmsc supports streaming inference. +# am can choose fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc, where fastspeech2_cnndecoder_csmsc supports streaming inference. # voc can choose hifigan_csmsc and mb_melgan_csmsc, They can both support streaming inference. -python test_online_tts.py --am fastspeech2-C_csmsc \ +# When am is fastspeech2_cnndecoder_csmsc and am_pad is set to 12, there is no diff between streaming and non-streaming inference results. 
+# When voc is mb_melgan_csmsc and voc_pad is set to 14, there is no diff between streaming and non-streaming inference results. +# When voc is hifigan_csmsc and voc_pad is set to 20, there is no diff between streaming and non-streaming inference results. + +python test_online_tts.py --am fastspeech2_cnndecoder_csmsc \ --am_config $am_model_dir/$am_config_file \ --am_ckpt $am_model_dir/$am_ckpt_file \ --am_stat $am_model_dir/$am_stat_file \ diff --git a/paddlespeech/server/tests/tts/infer/test_online_tts.py b/paddlespeech/server/tests/tts/infer/test_online_tts.py index 8ccf724b..eb5fc80b 100644 --- a/paddlespeech/server/tests/tts/infer/test_online_tts.py +++ b/paddlespeech/server/tests/tts/infer/test_online_tts.py @@ -34,8 +34,8 @@ from paddlespeech.t2s.utils import str2bool mel_streaming = None wav_streaming = None -stream_first_time = 0.0 -voc_stream_st = 0.0 +streaming_first_time = 0.0 +streaming_voc_st = 0.0 sample_rate = 0 @@ -65,7 +65,7 @@ def get_chunks(data, block_size, pad_size, step): return chunks -def get_stream_am_inference(args, am_config): +def get_streaming_am_inference(args, am_config): with open(args.phones_dict, "r") as f: phn_id = [line.strip().split() for line in f.readlines()] vocab_size = len(phn_id) @@ -99,8 +99,8 @@ def init(args): frontend = get_frontend(args) # acoustic model - if args.am == 'fastspeech2-C_csmsc': - am, am_mu, am_std = get_stream_am_inference(args, am_config) + if args.am == 'fastspeech2_cnndecoder_csmsc': + am, am_mu, am_std = get_streaming_am_inference(args, am_config) am_infer_info = [am, am_mu, am_std, am_config] else: am_inference, am_name, am_dataset = get_am_inference(args, am_config) @@ -139,7 +139,7 @@ def get_phone(args, frontend, sentence, merge_sentences, get_tone_ids): # 生成完整的mel def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids): # 如果是支持流式的AM模型 - if args.am == 'fastspeech2-C_csmsc': + if args.am == 'fastspeech2_cnndecoder_csmsc': am, am_mu, am_std, am_config = am_infer_info orig_hs, h_masks = am.encoder_infer(part_phone_ids) if args.am_streaming: @@ -183,9 +183,9 @@ def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids): @paddle.no_grad() -def stream_voc_infer(args, voc_infer_info, mel_len): +def streaming_voc_infer(args, voc_infer_info, mel_len): global mel_streaming - global stream_first_time + global streaming_first_time global wav_streaming voc_inference, voc_config = voc_infer_info block = args.voc_block @@ -203,7 +203,7 @@ def stream_voc_infer(args, voc_infer_info, mel_len): while valid_end <= mel_len: sub_wav = voc_inference(mel_chunk) if flag == 1: - stream_first_time = time.time() + streaming_first_time = time.time() flag = 0 # get valid wav @@ -233,8 +233,8 @@ def stream_voc_infer(args, voc_infer_info, mel_len): @paddle.no_grad() # 非流式AM / 流式AM + 非流式Voc -def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, - part_tone_ids): +def am_nonstreaming_voc(args, am_infer_info, voc_infer_info, part_phone_ids, + part_tone_ids): mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids) am_infer_time = time.time() voc_inference, voc_config = voc_infer_info @@ -248,10 +248,10 @@ def am_nostream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, @paddle.no_grad() # 非流式AM + 流式Voc -def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, - part_tone_ids): +def nonstreaming_am_streaming_voc(args, am_infer_info, voc_infer_info, + part_phone_ids, part_tone_ids): global mel_streaming - global stream_first_time + global streaming_first_time global wav_streaming mel 
= gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids) @@ -260,8 +260,8 @@ def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, # voc streaming mel_streaming = mel mel_len = mel.shape[0] - stream_voc_infer(args, voc_infer_info, mel_len) - first_response_time = stream_first_time + streaming_voc_infer(args, voc_infer_info, mel_len) + first_response_time = streaming_first_time wav = wav_streaming final_response_time = time.time() voc_infer_time = final_response_time @@ -271,12 +271,12 @@ def nostream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, @paddle.no_grad() # 流式AM + 流式 Voc -def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, - part_tone_ids): +def streaming_am_streaming_voc(args, am_infer_info, voc_infer_info, + part_phone_ids, part_tone_ids): global mel_streaming - global stream_first_time + global streaming_first_time global wav_streaming - global voc_stream_st + global streaming_voc_st mel_streaming = None #用来表示开启流式voc的线程 flag = 1 @@ -311,15 +311,16 @@ def stream_am_stream_voc(args, am_infer_info, voc_infer_info, part_phone_ids, if flag and mel_streaming.shape[0] > args.voc_block + args.voc_pad: t = threading.Thread( - target=stream_voc_infer, args=(args, voc_infer_info, mel_len, )) + target=streaming_voc_infer, + args=(args, voc_infer_info, mel_len, )) t.start() - voc_stream_st = time.time() + streaming_voc_st = time.time() flag = 0 t.join() final_response_time = time.time() voc_infer_time = final_response_time - first_response_time = stream_first_time + first_response_time = streaming_first_time wav = wav_streaming return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav @@ -337,11 +338,11 @@ def warm_up(args, logger, frontend, am_infer_info, voc_infer_info): if args.voc_streaming: if args.am_streaming: - infer_func = stream_am_stream_voc + infer_func = streaming_am_streaming_voc else: - infer_func = nostream_am_stream_voc + infer_func = nonstreaming_am_streaming_voc else: - infer_func = am_nostream_voc + infer_func = am_nonstreaming_voc merge_sentences = True get_tone_ids = False @@ -376,11 +377,11 @@ def evaluate(args, logger, frontend, am_infer_info, voc_infer_info): # choose infer function if args.voc_streaming: if args.am_streaming: - infer_func = stream_am_stream_voc + infer_func = streaming_am_streaming_voc else: - infer_func = nostream_am_stream_voc + infer_func = nonstreaming_am_streaming_voc else: - infer_func = am_nostream_voc + infer_func = am_nonstreaming_voc final_up_duration = 0.0 sentence_count = 0 @@ -410,7 +411,7 @@ def evaluate(args, logger, frontend, am_infer_info, voc_infer_info): args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids) am_time = am_infer_time - am_st if args.voc_streaming and args.am_streaming: - voc_time = voc_infer_time - voc_stream_st + voc_time = voc_infer_time - streaming_voc_st else: voc_time = voc_infer_time - am_infer_time @@ -482,8 +483,8 @@ def parse_args(): '--am', type=str, default='fastspeech2_csmsc', - choices=['fastspeech2_csmsc', 'fastspeech2-C_csmsc'], - help='Choose acoustic model type of tts task. where fastspeech2-C_csmsc supports streaming inference' + choices=['fastspeech2_csmsc', 'fastspeech2_cnndecoder_csmsc'], + help='Choose acoustic model type of tts task. 
where fastspeech2_cnndecoder_csmsc supports streaming inference' ) parser.add_argument( @@ -576,7 +577,7 @@ def main(): args = parse_args() paddle.set_device(args.device) if args.am_streaming: - assert (args.am == 'fastspeech2-C_csmsc') + assert (args.am == 'fastspeech2_cnndecoder_csmsc') logger = logging.getLogger() fhandler = logging.FileHandler(filename=args.log_file, mode='w') From af484fc980e9df51e6411a13d9d280a6447f0c26 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Thu, 14 Apr 2022 19:57:52 +0800 Subject: [PATCH 04/31] convert websockert results to str from bytest, test=doc --- .../server/engine/asr/online/asr_engine.py | 23 +++++++---- .../tests/asr/online/websocket_client.py | 40 +++++++++++++++---- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index ca82b615..cd5300fc 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -35,9 +35,9 @@ __all__ = ['ASREngine'] pretrained_models = { "deepspeech2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz', 'md5': - 'd5e076217cf60486519f72c217d21b9b', + '23e16c69730a1cb5d735c98c83c21e16', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -75,6 +75,7 @@ class ASRServerExecutor(ASRExecutor): if cfg_path is None or am_model is None or am_params is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str + logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path self.cfg_path = os.path.join(res_path, @@ -85,9 +86,6 @@ class ASRServerExecutor(ASRExecutor): self.am_params = os.path.join(res_path, pretrained_models[tag]['params']) logger.info(res_path) - logger.info(self.cfg_path) - logger.info(self.am_model) - logger.info(self.am_params) else: self.cfg_path = os.path.abspath(cfg_path) self.am_model = os.path.abspath(am_model) @@ -95,6 +93,10 @@ class ASRServerExecutor(ASRExecutor): self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) + logger.info(self.cfg_path) + logger.info(self.am_model) + logger.info(self.am_params) + #Init body. 
self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -112,15 +114,20 @@ class ASRServerExecutor(ASRExecutor): lm_url = pretrained_models[tag]['lm_url'] lm_md5 = pretrained_models[tag]['lm_md5'] + logger.info(f"Start to load language model {lm_url}") self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - raise Exception("wrong type") + # 开发 conformer 的流式模型 + logger.info("start to create the stream conformer asr engine") + # 复用cli里面的代码 + else: raise Exception("wrong type") # AM predictor + logger.info("ASR engine start to init the am predictor") self.am_predictor_conf = am_predictor_conf self.am_predictor = init_predictor( model_file=self.am_model, @@ -128,6 +135,7 @@ class ASRServerExecutor(ASRExecutor): predictor_conf=self.am_predictor_conf) # decoder + logger.info("ASR engine start to create the ctc decoder instance") self.decoder = CTCDecoder( odim=self.config.output_dim, # is in vocab enc_n_units=self.config.rnn_layer_size * 2, @@ -138,6 +146,7 @@ class ASRServerExecutor(ASRExecutor): grad_norm_type=self.config.get('ctc_grad_norm_type', None)) # init decoder + logger.info("ASR engine start to init the ctc decoder") cfg = self.config.decode decode_batch_size = 1 # for online self.decoder.init_decoder( @@ -215,7 +224,6 @@ class ASRServerExecutor(ASRExecutor): self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() - return trans_best[0] elif "conformer" in model_type or "transformer" in model_type: @@ -273,6 +281,7 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() + logger.info("create the online asr engine instache") def init(self, config: dict) -> bool: """init engine resource diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py index 58b1a452..049d707e 100644 --- a/paddlespeech/server/tests/asr/online/websocket_client.py +++ b/paddlespeech/server/tests/asr/online/websocket_client.py @@ -15,8 +15,10 @@ # -*- coding: UTF-8 -*- import argparse import asyncio +import codecs import json import logging +import os import numpy as np import soundfile @@ -54,12 +56,11 @@ class ASRAudioHandler: async def run(self, wavfile_path: str): logging.info("send a message to the server") - # 读取音频 # self.read_wave() - # 发送 websocket 的 handshake 协议头 + # send websocket handshake protocal async with websockets.connect(self.url) as ws: - # server 端已经接收到 handshake 协议头 - # 发送开始指令 + # server has already received handshake protocal + # client start to send the command audio_info = json.dumps( { "name": "test.wav", @@ -77,8 +78,9 @@ class ASRAudioHandler: for chunk_data in self.read_wave(wavfile_path): await ws.send(chunk_data.tobytes()) msg = await ws.recv() + msg = json.loads(msg) logging.info("receive msg={}".format(msg)) - + result = msg # finished audio_info = json.dumps( { @@ -91,16 +93,36 @@ class ASRAudioHandler: separators=(',', ': ')) await ws.send(audio_info) msg = await ws.recv() + # decode the bytes to str + msg = json.loads(msg) logging.info("receive msg={}".format(msg)) + return result + def main(args): logging.basicConfig(level=logging.INFO) logging.info("asr websocket client start") - handler = ASRAudioHandler("127.0.0.1", 8091) + handler = ASRAudioHandler("127.0.0.1", 8090) loop = asyncio.get_event_loop() - loop.run_until_complete(handler.run(args.wavfile)) - logging.info("asr 
websocket client finished") + + # support to process single audio file + if args.wavfile and os.path.exists(args.wavfile): + logging.info(f"start to process the wavscp: {args.wavfile}") + result = loop.run_until_complete(handler.run(args.wavfile)) + result = result["asr_results"] + logging.info(f"asr websocket client finished : {result}") + + # support to process batch audios from wav.scp + if args.wavscp and os.path.exists(args.wavscp): + logging.info(f"start to process the wavscp: {args.wavscp}") + with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\ + codecs.open("result.txt", 'w', encoding='utf-8') as w: + for line in f: + utt_name, utt_path = line.strip().split() + result = loop.run_until_complete(handler.run(utt_path)) + result = result["asr_results"] + w.write(f"{utt_name} {result}\n") if __name__ == "__main__": @@ -110,6 +132,8 @@ if __name__ == "__main__": action="store", help="wav file path ", default="./16_audio.wav") + parser.add_argument( + "--wavscp", type=str, default=None, help="The batch audios dict text") args = parser.parse_args() main(args) From d21ccd02875fea5d8c90483a31cd8b6f4a148d2e Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Fri, 15 Apr 2022 18:42:46 +0800 Subject: [PATCH 05/31] add conformer online server, test=doc --- paddlespeech/cli/asr/infer.py | 56 +++-- paddlespeech/s2t/models/u2/u2.py | 8 +- paddlespeech/s2t/modules/ctc.py | 3 +- paddlespeech/s2t/modules/encoder.py | 2 + paddlespeech/server/conf/ws_application.yaml | 54 +++- .../server/engine/asr/online/asr_engine.py | 231 ++++++++++++------ .../tests/asr/online/websocket_client.py | 2 +- paddlespeech/server/ws/asr_socket.py | 29 ++- 8 files changed, 272 insertions(+), 113 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index b12b9f6f..53f71a70 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -91,6 +91,20 @@ pretrained_models = { 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' + }, + "conformer2online_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz', + 'md5': + '4814e52e0fc2fd48899373f95c84b0c9', + 'cfg_path': + 'config.yaml', + 'ckpt_path': + 'exp/deepspeech2_online/checkpoints/avg_30', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' }, "deepspeech2offline_librispeech-en-16k": { 'url': @@ -115,6 +129,8 @@ model_alias = { "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "conformer": "paddlespeech.s2t.models.u2:U2Model", + "conformer2online": + "paddlespeech.s2t.models.u2:U2Model", "transformer": "paddlespeech.s2t.models.u2:U2Model", "wenetspeech": @@ -219,6 +235,7 @@ class ASRExecutor(BaseExecutor): """ Init model and other resources from a specific path. """ + logger.info("start to init the model") if hasattr(self, 'model'): logger.info('Model had been initialized.') return @@ -233,14 +250,15 @@ class ASRExecutor(BaseExecutor): self.ckpt_path = os.path.join( res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) - logger.info(self.cfg_path) - logger.info(self.ckpt_path) + else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - + logger.info(self.cfg_path) + logger.info(self.ckpt_path) + #Init body. 
self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -269,7 +287,6 @@ class ASRExecutor(BaseExecutor): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) self.config.decode.decoding_method = decode_method - else: raise Exception("wrong type") model_name = model_type[:model_type.rindex( @@ -347,12 +364,14 @@ class ASRExecutor(BaseExecutor): else: raise Exception("wrong type") + logger.info("audio feat process success") + @paddle.no_grad() def infer(self, model_type: str): """ Model inference and result stored in self.output. """ - + logger.info("start to infer the model to get the output") cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] @@ -369,17 +388,22 @@ class ASRExecutor(BaseExecutor): self._outputs["result"] = result_transcripts[0] elif "conformer" in model_type or "transformer" in model_type: - result_transcripts = self.model.decode( - audio, - audio_len, - text_feature=self.text_feature, - decoding_method=cfg.decoding_method, - beam_size=cfg.beam_size, - ctc_weight=cfg.ctc_weight, - decoding_chunk_size=cfg.decoding_chunk_size, - num_decoding_left_chunks=cfg.num_decoding_left_chunks, - simulate_streaming=cfg.simulate_streaming) - self._outputs["result"] = result_transcripts[0][0] + logger.info(f"we will use the transformer like model : {model_type}") + try: + result_transcripts = self.model.decode( + audio, + audio_len, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + self._outputs["result"] = result_transcripts[0][0] + except Exception as e: + logger.exception(e) + else: raise Exception("invalid model name") diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 6a98607b..f0d2711d 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -213,12 +213,14 @@ class U2BaseModel(ASRInterface, nn.Layer): num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) else: + print("offline decode from the asr") encoder_out, encoder_mask = self.encoder( speech, speech_lengths, decoding_chunk_size=decoding_chunk_size, num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) + print("offline decode success") return encoder_out, encoder_mask def recognize( @@ -706,13 +708,15 @@ class U2BaseModel(ASRInterface, nn.Layer): List[List[int]]: transcripts. 
""" batch_size = feats.shape[0] + print("start to decode the audio feat") if decoding_method in ['ctc_prefix_beam_search', 'attention_rescoring'] and batch_size > 1: - logger.fatal( + logger.error( f'decoding mode {decoding_method} must be running with batch_size == 1' ) + logger.error(f"current batch_size is {batch_size}") sys.exit(1) - + print(f"use the {decoding_method} to decode the audio feat") if decoding_method == 'attention': hyps = self.recognize( feats, diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 33ad472d..bd1219b1 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -180,7 +180,8 @@ class CTCDecoder(CTCDecoderBase): # init once if self._ext_scorer is not None: return - + + from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401 if language_model_path != '': logger.info("begin to initialize the external scorer " "for decoding") diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index c843c0e2..347035cd 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -317,6 +317,8 @@ class BaseEncoder(nn.Layer): outputs = [] offset = 0 # Feed forward overlap input step by step + print(f"context: {context}") + print(f"stride: {stride}") for cur in range(0, num_frames - context + 1, stride): end = min(cur + decoding_window, num_frames) chunk_xs = xs[:, cur:end, :] diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index ef23593e..6b82edcb 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -4,7 +4,7 @@ # SERVER SETTING # ################################################################################# host: 0.0.0.0 -port: 8091 +port: 8096 # The task format in the engin_list is: _ # task choices = ['asr_online', 'tts_online'] @@ -18,10 +18,44 @@ engine_list: ['asr_online'] # ENGINE CONFIG # ################################################################################# +# ################################### ASR ######################################### +# ################### speech task: asr; engine_type: online ####################### +# asr_online: +# model_type: 'deepspeech2online_aishell' +# am_model: # the pdmodel file of am static model [optional] +# am_params: # the pdiparams file of am static model [optional] +# lang: 'zh' +# sample_rate: 16000 +# cfg_path: +# decode_method: +# force_yes: True + +# am_predictor_conf: +# device: # set 'gpu:id' or 'cpu' +# switch_ir_optim: True +# glog_info: False # True -> print glog +# summary: True # False -> do not show predictor config + +# chunk_buffer_conf: +# frame_duration_ms: 80 +# shift_ms: 40 +# sample_rate: 16000 +# sample_width: 2 + +# vad_conf: +# aggressiveness: 2 +# sample_rate: 16000 +# frame_duration_ms: 20 +# sample_width: 2 +# padding_ms: 200 +# padding_ratio: 0.9 + + + ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### asr_online: - model_type: 'deepspeech2online_aishell' + model_type: 'conformer2online_aishell' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' @@ -37,15 +71,15 @@ asr_online: summary: True # False -> do not show predictor config chunk_buffer_conf: - frame_duration_ms: 80 + frame_duration_ms: 85 shift_ms: 40 sample_rate: 16000 sample_width: 2 - 
vad_conf: - aggressiveness: 2 - sample_rate: 16000 - frame_duration_ms: 20 - sample_width: 2 - padding_ms: 200 - padding_ratio: 0.9 + # vad_conf: + # aggressiveness: 2 + # sample_rate: 16000 + # frame_duration_ms: 20 + # sample_width: 2 + # padding_ms: 200 + # padding_ratio: 0.9 \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index cd5300fc..a5b9ab48 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -20,11 +20,15 @@ from numpy import float32 from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.asr.infer import model_alias +from paddlespeech.cli.asr.infer import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import pcm2float @@ -51,6 +55,24 @@ pretrained_models = { 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' }, + "conformer2online_aishell-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz', + 'md5': + '4814e52e0fc2fd48899373f95c84b0c9', + 'cfg_path': + 'exp/chunk_conformer//conf/config.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/avg_30/', + 'model': + 'exp/chunk_conformer/checkpoints/avg_30.pdparams', + 'params': + 'exp/chunk_conformer/checkpoints/avg_30.pdparams', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, } @@ -71,15 +93,17 @@ class ASRServerExecutor(ASRExecutor): """ Init model and other resources from a specific path. 
""" - + self.model_type = model_type + self.sample_rate = sample_rate if cfg_path is None or am_model is None or am_params is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = os.path.join(res_path, - pretrained_models[tag]['cfg_path']) + self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/paddlespeech/server/tests/asr/online/conf/config.yaml" + # self.cfg_path = os.path.join(res_path, + # pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) @@ -119,49 +143,67 @@ class ASRServerExecutor(ASRExecutor): lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - # 开发 conformer 的流式模型 logger.info("start to create the stream conformer asr engine") - # 复用cli里面的代码 - + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + self.config.vocab_filepath = os.path.join( + self.res_path, self.config.vocab_filepath) + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + # update the decoding method + if decode_method: + self.config.decode.decoding_method = decode_method else: raise Exception("wrong type") - - # AM predictor - logger.info("ASR engine start to init the am predictor") - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - - # decoder - logger.info("ASR engine start to create the ctc decoder instance") - self.decoder = CTCDecoder( - odim=self.config.output_dim, # is in vocab - enc_n_units=self.config.rnn_layer_size * 2, - blank_id=self.config.blank_id, - dropout_rate=0.0, - reduction=True, # sum - batch_average=True, # sum / batch_size - grad_norm_type=self.config.get('ctc_grad_norm_type', None)) - - # init decoder - logger.info("ASR engine start to init the ctc decoder") - cfg = self.config.decode - decode_batch_size = 1 # for online - self.decoder.init_decoder( - decode_batch_size, self.text_feature.vocab_list, - cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, - cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, - cfg.num_proc_bsearch) - - # init state box - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + + # decoder + logger.info("ASR engine start to create the ctc decoder instance") + self.decoder = CTCDecoder( + odim=self.config.output_dim, # is in vocab + enc_n_units=self.config.rnn_layer_size * 2, + blank_id=self.config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + 
grad_norm_type=self.config.get('ctc_grad_norm_type', None)) + + # init decoder + logger.info("ASR engine start to init the ctc decoder") + cfg = self.config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + + # init state box + self.chunk_state_h_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + model_name = model_type[:model_type.rindex( + '_')] # model_type: {model_name}_{dataset} + logger.info(f"model name: {model_name}") + model_class = dynamic_import(model_name, model_alias) + model_conf = self.config + model = model_class.from_config(model_conf) + self.model = model + logger.info("create the transformer like model success") def reset_decoder_and_chunk(self): """reset decoder and chunk state for an new audio @@ -186,6 +228,7 @@ class ASRServerExecutor(ASRExecutor): Returns: [type]: [description] """ + logger.info("start to decoce chunk by chunk") if "deepspeech2online" in model_type: input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) @@ -224,10 +267,29 @@ class ASRServerExecutor(ASRExecutor): self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() + logger.info(f"decode one one best result: {trans_best[0]}") return trans_best[0] elif "conformer" in model_type or "transformer" in model_type: - raise Exception("invalid model name") + try: + logger.info( + f"we will use the transformer like model : {self.model_type}" + ) + cfg = self.config.decode + result_transcripts = self.model.decode( + x_chunk, + x_chunk_lens, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + + return result_transcripts[0][0] + except Exception as e: + logger.exception(e) else: raise Exception("invalid model name") @@ -244,32 +306,55 @@ class ASRServerExecutor(ASRExecutor): """ # pcm16 -> pcm 32 samples = pcm2float(samples) - - # read audio - speech_segment = SpeechSegment.from_pcm( - samples, sample_rate, transcript=" ") - # audio augment - self.collate_fn_test.augmentation.transform_audio(speech_segment) - - # extract speech feature - spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( - speech_segment, self.collate_fn_test.keep_transcription_text) - # CMVN spectrum - if self.collate_fn_test._normalizer: - spectrum = self.collate_fn_test._normalizer.apply(spectrum) - - # spectrum augment - audio = self.collate_fn_test.augmentation.transform_feature(spectrum) - - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - # audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) - - x_chunk = audio.numpy() - x_chunk_lens = np.array([audio_len]) - - return x_chunk, x_chunk_lens + if "deepspeech2online" in self.model_type: + # read audio + speech_segment = SpeechSegment.from_pcm( + samples, sample_rate, transcript=" ") + # audio augment + 
self.collate_fn_test.augmentation.transform_audio(speech_segment) + + # extract speech feature + spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( + speech_segment, self.collate_fn_test.keep_transcription_text) + # CMVN spectrum + if self.collate_fn_test._normalizer: + spectrum = self.collate_fn_test._normalizer.apply(spectrum) + + # spectrum augment + audio = self.collate_fn_test.augmentation.transform_feature( + spectrum) + + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + # audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + + x_chunk = audio.numpy() + x_chunk_lens = np.array([audio_len]) + + return x_chunk, x_chunk_lens + elif "conformer2online" in self.model_type: + + if sample_rate != self.sample_rate: + logger.info(f"audio sample rate {sample_rate} is not match," \ + "the model sample_rate is {self.sample_rate}") + logger.info(f"ASR Engine use the {self.model_type} to process") + logger.info("Create the preprocess instance") + preprocess_conf = self.config.preprocess_config + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + + logger.info("Read the audio file") + logger.info(f"audio shape: {samples.shape}") + # fbank + x_chunk = preprocessing(samples, **preprocess_args) + x_chunk_lens = paddle.to_tensor(x_chunk.shape[0]) + x_chunk = paddle.to_tensor( + x_chunk, dtype="float32").unsqueeze(axis=0) + logger.info( + f"process the audio feature success, feat shape: {x_chunk.shape}" + ) + return x_chunk, x_chunk_lens class ASREngine(BaseEngine): @@ -310,7 +395,10 @@ class ASREngine(BaseEngine): logger.info("Initialize ASR server engine successfully.") return True - def preprocess(self, samples, sample_rate): + def preprocess(self, + samples, + sample_rate, + model_type="deepspeech2online_aishell-zh-16k"): """preprocess Args: @@ -321,6 +409,7 @@ class ASREngine(BaseEngine): x_chunk (numpy.array): shape[B, T, D] x_chunk_lens (numpy.array): shape[B] """ + # if "deepspeech" in model_type: x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate) return x_chunk, x_chunk_lens diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py index 049d707e..a26838f8 100644 --- a/paddlespeech/server/tests/asr/online/websocket_client.py +++ b/paddlespeech/server/tests/asr/online/websocket_client.py @@ -103,7 +103,7 @@ class ASRAudioHandler: def main(args): logging.basicConfig(level=logging.INFO) logging.info("asr websocket client start") - handler = ASRAudioHandler("127.0.0.1", 8090) + handler = ASRAudioHandler("127.0.0.1", 8096) loop = asyncio.get_event_loop() # support to process single audio file diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index ea19816b..442f26cb 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -14,6 +14,7 @@ import json import numpy as np +import json from fastapi import APIRouter from fastapi import WebSocket from fastapi import WebSocketDisconnect @@ -28,7 +29,7 @@ router = APIRouter() @router.websocket('/ws/asr') async def websocket_endpoint(websocket: WebSocket): - + print("websocket protocal receive the dataset") await websocket.accept() engine_pool = get_engine_pool() @@ -36,14 +37,18 @@ async def websocket_endpoint(websocket: WebSocket): # init buffer chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer = ChunkBuffer( + 
frame_duration_ms=chunk_buffer_conf['frame_duration_ms'], sample_rate=chunk_buffer_conf['sample_rate'], sample_width=chunk_buffer_conf['sample_width']) # init vad - vad_conf = asr_engine.config.vad_conf - vad = VADAudio( - aggressiveness=vad_conf['aggressiveness'], - rate=vad_conf['sample_rate'], - frame_duration_ms=vad_conf['frame_duration_ms']) + # print(asr_engine.config) + # print(type(asr_engine.config)) + vad_conf = asr_engine.config.get('vad_conf', None) + if vad_conf: + vad = VADAudio( + aggressiveness=vad_conf['aggressiveness'], + rate=vad_conf['sample_rate'], + frame_duration_ms=vad_conf['frame_duration_ms']) try: while True: @@ -65,7 +70,7 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] # reset single engine for an new connection - asr_engine.reset() + # asr_engine.reset() resp = {"status": "ok", "signal": "finished"} await websocket.send_json(resp) break @@ -75,16 +80,16 @@ async def websocket_endpoint(websocket: WebSocket): elif "bytes" in message: message = message["bytes"] - # vad for input bytes audio - vad.add_audio(message) - message = b''.join(f for f in vad.vad_collector() - if f is not None) - + # # vad for input bytes audio + # vad.add_audio(message) + # message = b''.join(f for f in vad.vad_collector() + # if f is not None) engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] asr_results = "" frames = chunk_buffer.frame_generator(message) for frame in frames: + # get the pcm data from the bytes samples = np.frombuffer(frame.bytes, dtype=np.int16) sample_rate = asr_engine.config.sample_rate x_chunk, x_chunk_lens = asr_engine.preprocess(samples, From 0c5dbbee5bdb784e44d7f6ad1f7a7d911c833e06 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Sat, 16 Apr 2022 21:37:46 +0800 Subject: [PATCH 06/31] add conformer ctc prefix beam search decoding method, test=doc --- .../server/engine/asr/online/asr_engine.py | 213 +++++++++++++++--- paddlespeech/server/ws/asr_socket.py | 26 ++- 2 files changed, 195 insertions(+), 44 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index a5b9ab48..e1e4a7ad 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
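Editor's note: a minimal client sketch for the /ws/asr handler shown above. It assumes the third-party websockets package, a 16 kHz / 16-bit mono wav file, and port 8096 from ws_application.yaml; the "start" handshake fields are an assumption modelled on the test client, since only the per-chunk {"asr_results": ...} replies and the "end" reply are visible in this patch.

import asyncio
import json
import wave

import websockets

async def stream_wav(wav_path, url="ws://127.0.0.1:8096/ws/asr", chunk_ms=40):
    with wave.open(wav_path, "rb") as wav:
        assert wav.getframerate() == 16000 and wav.getsampwidth() == 2
        frames_per_chunk = wav.getframerate() * chunk_ms // 1000
        async with websockets.connect(url) as ws:
            # assumed handshake; the server branch handling "start" is not shown here
            await ws.send(json.dumps({"name": wav_path, "signal": "start"}))
            while True:
                pcm = wav.readframes(frames_per_chunk)
                if not pcm:
                    break
                await ws.send(pcm)                      # binary pcm chunk
                print(json.loads(await ws.recv()))      # {"asr_results": "..."}
            await ws.send(json.dumps({"name": wav_path, "signal": "end"}))
            print(json.loads(await ws.recv()))          # {"status": "ok", "signal": "finished"}

# asyncio.run(stream_wav("./16_audio.wav"))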
import os +from paddlespeech.s2t.utils.utility import log_add from typing import Optional - +from collections import defaultdict import numpy as np import paddle from numpy import float32 @@ -23,10 +24,14 @@ from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import model_alias from paddlespeech.cli.asr.infer import pretrained_models from paddlespeech.cli.log import logger +from paddlespeech.cli.utils import download_and_decompress from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.modules.mask import mask_finished_preds +from paddlespeech.s2t.modules.mask import mask_finished_scores +from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.utility import UpdateConfig @@ -57,17 +62,17 @@ pretrained_models = { }, "conformer2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr1_chunk_conformer_aishell_ckpt_0.1.2.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz', 'md5': - '4814e52e0fc2fd48899373f95c84b0c9', + '7989b3248c898070904cf042fd656003', 'cfg_path': - 'exp/chunk_conformer//conf/config.yaml', + 'model.yaml', 'ckpt_path': - 'exp/chunk_conformer/checkpoints/avg_30/', + 'exp/chunk_conformer/checkpoints/multi_cn', 'model': - 'exp/chunk_conformer/checkpoints/avg_30.pdparams', + 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', 'params': - 'exp/chunk_conformer/checkpoints/avg_30.pdparams', + 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', 'lm_url': 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', 'lm_md5': @@ -81,6 +86,23 @@ class ASRServerExecutor(ASRExecutor): super().__init__() pass + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. 
+ """ + support_models = list(pretrained_models.keys()) + assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path + def _init_from_path(self, model_type: str='wenetspeech', am_model: Optional[os.PathLike]=None, @@ -101,7 +123,7 @@ class ASRServerExecutor(ASRExecutor): logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/paddlespeech/server/tests/asr/online/conf/config.yaml" + self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" # self.cfg_path = os.path.join(res_path, # pretrained_models[tag]['cfg_path']) @@ -147,8 +169,7 @@ class ASRServerExecutor(ASRExecutor): if self.config.spm_model_prefix: self.config.spm_model_prefix = os.path.join( self.res_path, self.config.spm_model_prefix) - self.config.vocab_filepath = os.path.join( - self.res_path, self.config.vocab_filepath) + self.vocab = self.config.vocab_filepath self.text_feature = TextFeaturizer( unit_type=self.config.unit_type, vocab=self.config.vocab_filepath, @@ -203,19 +224,31 @@ class ASRServerExecutor(ASRExecutor): model_conf = self.config model = model_class.from_config(model_conf) self.model = model + self.model.eval() + + # load model + model_dict = paddle.load(self.am_model) + self.model.set_state_dict(model_dict) logger.info("create the transformer like model success") + # update the ctc decoding + self.searcher = None + self.transformer_decode_reset() + def reset_decoder_and_chunk(self): """reset decoder and chunk state for an new audio """ - self.decoder.reset_decoder(batch_size=1) - # init state box, for new audio request - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + self.decoder.reset_decoder(batch_size=1) + # init state box, for new audio request + self.chunk_state_h_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: + self.transformer_decode_reset() def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str): """decode one chunk @@ -275,24 +308,137 @@ class ASRServerExecutor(ASRExecutor): logger.info( f"we will use the transformer like model : {self.model_type}" ) - cfg = self.config.decode - result_transcripts = self.model.decode( - x_chunk, - x_chunk_lens, - text_feature=self.text_feature, - decoding_method=cfg.decoding_method, - beam_size=cfg.beam_size, - ctc_weight=cfg.ctc_weight, - decoding_chunk_size=cfg.decoding_chunk_size, - num_decoding_left_chunks=cfg.num_decoding_left_chunks, - 
simulate_streaming=cfg.simulate_streaming) - - return result_transcripts[0][0] + self.advanced_decoding(x_chunk, x_chunk_lens) + self.update_result() + + return self.result_transcripts[0] except Exception as e: logger.exception(e) else: raise Exception("invalid model name") + def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): + logger.info("start to decode with advanced_decoding method") + encoder_out, encoder_mask = self.decode_forward(xs) + self.ctc_prefix_beam_search(xs, encoder_out, encoder_mask) + + def decode_forward(self, xs): + logger.info("get the model out from the feat") + cfg = self.config.decode + decoding_chunk_size = cfg.decoding_chunk_size + num_decoding_left_chunks = cfg.num_decoding_left_chunks + + assert decoding_chunk_size > 0 + subsampling = self.model.encoder.embed.subsampling_rate + context = self.model.encoder.embed.right_context + 1 + stride = subsampling * decoding_chunk_size + + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.shape[1] + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + + logger.info("start to do model forward") + outputs = [] + + # num_frames - context + 1 ensure that current frame can get context window + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + chunk_xs, self.offset, required_cache_size, + self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) + outputs.append(y) + self.offset += y.shape[1] + + ys = paddle.cat(outputs, 1) + masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) + masks = masks.unsqueeze(1) + return ys, masks + + def transformer_decode_reset(self): + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.hyps = None + self.offset = 0 + self.cur_hyps = None + self.hyps = None + + def ctc_prefix_beam_search(self, xs, encoder_out, encoder_mask, blank_id=0): + # decode + logger.info("start to ctc prefix search") + + device = xs.place + cfg = self.config.decode + batch_size = xs.shape[0] + beam_size = cfg.beam_size + maxlen = encoder_out.shape[1] + + ctc_probs = self.model.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + # blank_ending_score and none_blank_ending_score in ln domain + if self.cur_hyps is None: + self.cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. 
CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + + # 2.1 First beam prune: select topk best + # do token passing process + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in self.cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == blank_id: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + self.cur_hyps = next_hyps[:beam_size] + + hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] + + self.hyps = [hyps[0][0]] + logger.info("ctc prefix search success") + return hyps, encoder_out + + def update_result(self): + logger.info("update the final result") + self.result_transcripts = [ + self.text_feature.defeaturize(hyp) for hyp in self.hyps + ] + self.result_tokenids = [hyp for hyp in self.hyps] + def extract_feat(self, samples, sample_rate): """extract feat @@ -304,9 +450,10 @@ class ASRServerExecutor(ASRExecutor): x_chunk (numpy.array): shape[B, T, D] x_chunk_lens (numpy.array): shape[B] """ - # pcm16 -> pcm 32 - samples = pcm2float(samples) + if "deepspeech2online" in self.model_type: + # pcm16 -> pcm 32 + samples = pcm2float(samples) # read audio speech_segment = SpeechSegment.from_pcm( samples, sample_rate, transcript=" ") diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 14254928..4d1013f4 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -14,7 +14,6 @@ import json import numpy as np -import json from fastapi import APIRouter from fastapi import WebSocket from fastapi import WebSocketDisconnect @@ -86,16 +85,21 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] asr_results = "" - frames = chunk_buffer.frame_generator(message) - for frame in frames: - # get the pcm data from the bytes - samples = np.frombuffer(frame.bytes, dtype=np.int16) - sample_rate = asr_engine.config.sample_rate - x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - sample_rate) - asr_engine.run(x_chunk, x_chunk_lens) - asr_results = asr_engine.postprocess() - + # frames = chunk_buffer.frame_generator(message) + # for frame in frames: + # # get the pcm data from the bytes + # samples = np.frombuffer(frame.bytes, dtype=np.int16) + # sample_rate = asr_engine.config.sample_rate + # x_chunk, x_chunk_lens = asr_engine.preprocess(samples, + # sample_rate) + # asr_engine.run(x_chunk, x_chunk_lens) + # asr_results = asr_engine.postprocess() + samples = np.frombuffer(message, dtype=np.int16) + sample_rate = asr_engine.config.sample_rate + x_chunk, x_chunk_lens = asr_engine.preprocess(samples, + 
sample_rate) + asr_engine.run(x_chunk, x_chunk_lens) + # asr_results = asr_engine.postprocess() asr_results = asr_engine.postprocess() resp = {'asr_results': asr_results} From 97d31f9aacc37e936d70f0a10bccf1622fd69323 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Sun, 17 Apr 2022 16:27:35 +0800 Subject: [PATCH 07/31] update the attention_rescoring method, test=doc --- paddlespeech/server/conf/ws_application.yaml | 16 +- .../server/engine/asr/online/asr_engine.py | 177 +++++++++--------- .../server/engine/asr/online/ctc_search.py | 119 ++++++++++++ paddlespeech/server/tests/__init__.py | 13 ++ paddlespeech/server/tests/asr/__init__.py | 13 ++ .../server/tests/asr/offline/__init__.py | 13 ++ .../server/tests/asr/online/__init__.py | 13 ++ paddlespeech/server/ws/asr_socket.py | 43 ++--- 8 files changed, 287 insertions(+), 120 deletions(-) create mode 100644 paddlespeech/server/engine/asr/online/ctc_search.py create mode 100644 paddlespeech/server/tests/__init__.py create mode 100644 paddlespeech/server/tests/asr/__init__.py create mode 100644 paddlespeech/server/tests/asr/offline/__init__.py create mode 100644 paddlespeech/server/tests/asr/online/__init__.py diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index c3a488fb..aa3c208b 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -71,15 +71,9 @@ asr_online: summary: True # False -> do not show predictor config chunk_buffer_conf: - frame_duration_ms: 85 - shift_ms: 40 + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 25 # ms + shift_ms: 10 # ms sample_rate: 16000 - sample_width: 2 - - # vad_conf: - # aggressiveness: 2 - # sample_rate: 16000 - # frame_duration_ms: 20 - # sample_width: 2 - # padding_ms: 200 - # padding_ratio: 0.9 \ No newline at end of file + sample_width: 2 \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index e1e4a7ad..e292f9cf 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
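Editor's note: the new chunk_buffer_conf above expresses the buffer in frame counts rather than milliseconds. The sketch below only illustrates the arithmetic those numbers imply, not the actual ChunkBuffer implementation: 25 ms + 6 x 10 ms = 85 ms per window and 4 x 10 ms = 40 ms per shift, i.e. the same 85 ms / 40 ms the previous config stated directly.

def chunk_sizes(window_n=7, shift_n=4, window_ms=25, shift_ms=10,
                sample_rate=16000, sample_width=2):
    # one buffered window spans window_ms + (window_n - 1) * shift_ms milliseconds
    window_sec = (window_ms + (window_n - 1) * shift_ms) / 1000.0
    shift_sec = (shift_n * shift_ms) / 1000.0
    window_bytes = int(window_sec * sample_rate) * sample_width
    shift_bytes = int(shift_sec * sample_rate) * sample_width
    return window_bytes, shift_bytes

print(chunk_sizes())  # (2720, 1280): bytes per buffered window / per shift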
import os -from paddlespeech.s2t.utils.utility import log_add from typing import Optional -from collections import defaultdict + import numpy as np import paddle from numpy import float32 @@ -22,19 +21,18 @@ from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import model_alias -from paddlespeech.cli.asr.infer import pretrained_models from paddlespeech.cli.log import logger from paddlespeech.cli.utils import download_and_decompress from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.s2t.modules.mask import mask_finished_preds -from paddlespeech.s2t.modules.mask import mask_finished_scores -from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.s2t.utils.tensor_utils import add_sos_eos +from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor @@ -62,9 +60,9 @@ pretrained_models = { }, "conformer2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.1.model.tar.gz', 'md5': - '7989b3248c898070904cf042fd656003', + 'b450d5dfaea0ac227c595ce58d18b637', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -123,9 +121,9 @@ class ASRServerExecutor(ASRExecutor): logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" - # self.cfg_path = os.path.join(res_path, - # pretrained_models[tag]['cfg_path']) + # self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" + self.cfg_path = os.path.join(res_path, + pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) @@ -177,6 +175,18 @@ class ASRServerExecutor(ASRExecutor): # update the decoding method if decode_method: self.config.decode.decoding_method = decode_method + + # we only support ctc_prefix_beam_search and attention_rescoring dedoding method + # Generally we set the decoding_method to attention_rescoring + if self.config.decode.decoding_method not in [ + "ctc_prefix_beam_search", "attention_rescoring" + ]: + logger.info( + "we set the decoding_method to attention_rescoring") + self.config.decode.decoding = "attention_rescoring" + assert self.config.decode.decoding_method in [ + "ctc_prefix_beam_search", "attention_rescoring" + ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" else: raise Exception("wrong type") if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: @@ -232,7 +242,7 @@ class 
ASRServerExecutor(ASRExecutor): logger.info("create the transformer like model success") # update the ctc decoding - self.searcher = None + self.searcher = CTCPrefixBeamSearch(self.config.decode) self.transformer_decode_reset() def reset_decoder_and_chunk(self): @@ -320,7 +330,16 @@ class ASRServerExecutor(ASRExecutor): def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): logger.info("start to decode with advanced_decoding method") encoder_out, encoder_mask = self.decode_forward(xs) - self.ctc_prefix_beam_search(xs, encoder_out, encoder_mask) + ctc_probs = self.model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + self.searcher.search(xs, ctc_probs, xs.place) + # update the one best result + self.hyps = self.searcher.get_one_best_hyps() + + # now we supprot ctc_prefix_beam_search and attention_rescoring + if "attention_rescoring" in self.config.decode.decoding_method: + self.rescoring(encoder_out, xs.place) def decode_forward(self, xs): logger.info("get the model out from the feat") @@ -338,7 +357,6 @@ class ASRServerExecutor(ASRExecutor): num_frames = xs.shape[1] required_cache_size = decoding_chunk_size * num_decoding_left_chunks - logger.info("start to do model forward") outputs = [] @@ -359,85 +377,74 @@ class ASRServerExecutor(ASRExecutor): masks = masks.unsqueeze(1) return ys, masks + def rescoring(self, encoder_out, device): + logger.info("start to rescoring the hyps") + beam_size = self.config.decode.beam_size + hyps = self.searcher.get_hyps() + assert len(hyps) == beam_size + + hyp_list = [] + for hyp in hyps: + hyp_content = hyp[0] + # Prevent the hyp is empty + if len(hyp_content) == 0: + hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( + hyp_content, place=device, dtype=paddle.long) + hyp_list.append(hyp_content) + hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + hyps_lens = paddle.to_tensor( + [len(hyp[0]) for hyp in hyps], place=device, + dtype=paddle.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, + self.model.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + + encoder_out = encoder_out.repeat(beam_size, 1, 1) + encoder_mask = paddle.ones( + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + decoder_out, _ = self.model.decoder( + encoder_out, encoder_mask, hyps_pad, + hyps_lens) # (beam_size, max_hyps_len, vocab_size) + # ctc score in ln domain + decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) + decoder_out = decoder_out.numpy() + + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + # last decoder output token is `eos`, for laste decoder input token. 
+ score += decoder_out[i][len(hyp[0])][self.model.eos] + # add ctc score (which in ln domain) + score += hyp[1] * self.config.decode.ctc_weight + if score > best_score: + best_score = score + best_index = i + + # update the one best result + self.hyps = [hyps[best_index][0]] + return hyps[best_index][0] + def transformer_decode_reset(self): self.subsampling_cache = None self.elayers_output_cache = None self.conformer_cnn_cache = None - self.hyps = None self.offset = 0 - self.cur_hyps = None - self.hyps = None - - def ctc_prefix_beam_search(self, xs, encoder_out, encoder_mask, blank_id=0): - # decode - logger.info("start to ctc prefix search") - - device = xs.place - cfg = self.config.decode - batch_size = xs.shape[0] - beam_size = cfg.beam_size - maxlen = encoder_out.shape[1] - - ctc_probs = self.model.ctc.log_softmax(encoder_out) # (1, maxlen, vocab_size) - ctc_probs = ctc_probs.squeeze(0) - - # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) - # blank_ending_score and none_blank_ending_score in ln domain - if self.cur_hyps is None: - self.cur_hyps = [(tuple(), (0.0, -float('inf')))] - # 2. CTC beam search step by step - for t in range(0, maxlen): - logp = ctc_probs[t] # (vocab_size,) - # key: prefix, value (pb, pnb), default value(-inf, -inf) - next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) - - # 2.1 First beam prune: select topk best - # do token passing process - top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) - for s in top_k_index: - s = s.item() - ps = logp[s].item() - for prefix, (pb, pnb) in self.cur_hyps: - last = prefix[-1] if len(prefix) > 0 else None - if s == blank_id: # blank - n_pb, n_pnb = next_hyps[prefix] - n_pb = log_add([n_pb, pb + ps, pnb + ps]) - next_hyps[prefix] = (n_pb, n_pnb) - elif s == last: - # Update *ss -> *s; - n_pb, n_pnb = next_hyps[prefix] - n_pnb = log_add([n_pnb, pnb + ps]) - next_hyps[prefix] = (n_pb, n_pnb) - # Update *s-s -> *ss, - is for blank - n_prefix = prefix + (s, ) - n_pb, n_pnb = next_hyps[n_prefix] - n_pnb = log_add([n_pnb, pb + ps]) - next_hyps[n_prefix] = (n_pb, n_pnb) - else: - n_prefix = prefix + (s, ) - n_pb, n_pnb = next_hyps[n_prefix] - n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) - next_hyps[n_prefix] = (n_pb, n_pnb) - - # 2.2 Second beam prune - next_hyps = sorted( - next_hyps.items(), - key=lambda x: log_add(list(x[1])), - reverse=True) - self.cur_hyps = next_hyps[:beam_size] - - hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] - - self.hyps = [hyps[0][0]] - logger.info("ctc prefix search success") - return hyps, encoder_out + # decoding reset + self.searcher.reset() def update_result(self): logger.info("update the final result") + hyps = self.hyps self.result_transcripts = [ - self.text_feature.defeaturize(hyp) for hyp in self.hyps + self.text_feature.defeaturize(hyp) for hyp in hyps ] - self.result_tokenids = [hyp for hyp in self.hyps] + self.result_tokenids = [hyp for hyp in hyps] def extract_feat(self, samples, sample_rate): """extract feat @@ -483,9 +490,9 @@ class ASRServerExecutor(ASRExecutor): elif "conformer2online" in self.model_type: if sample_rate != self.sample_rate: - logger.info(f"audio sample rate {sample_rate} is not match," \ + logger.info(f"audio sample rate {sample_rate} is not match," "the model sample_rate is {self.sample_rate}") - logger.info(f"ASR Engine use the {self.model_type} to process") + logger.info("ASR Engine use the {self.model_type} to process") logger.info("Create the preprocess instance") preprocess_conf = 
self.config.preprocess_config preprocess_args = {"train": False} diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py new file mode 100644 index 00000000..a91b8a21 --- /dev/null +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -0,0 +1,119 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict + +from paddlespeech.cli.log import logger +from paddlespeech.s2t.utils.utility import log_add + +__all__ = ['CTCPrefixBeamSearch'] + + +class CTCPrefixBeamSearch: + def __init__(self, config): + """Implement the ctc prefix beam search + + Args: + config (_type_): _description_ + """ + self.config = config + self.reset() + + def search(self, xs, ctc_probs, device, blank_id=0): + """ctc prefix beam search method decode a chunk feature + + Args: + xs (paddle.Tensor): feature data + ctc_probs (paddle.Tensor): the ctc probability of all the tokens + encoder_out (paddle.Tensor): _description_ + encoder_mask (_type_): _description_ + blank_id (int, optional): the blank id in the vocab. Defaults to 0. + + Returns: + list: the search result + """ + # decode + logger.info("start to ctc prefix search") + + # device = xs.place + batch_size = xs.shape[0] + beam_size = self.config.beam_size + maxlen = ctc_probs.shape[0] + + assert len(ctc_probs.shape) == 2 + + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + # blank_ending_score and none_blank_ending_score in ln domain + if self.cur_hyps is None: + self.cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. 
CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + + # 2.1 First beam prune: select topk best + # do token passing process + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in self.cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == blank_id: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + self.cur_hyps = next_hyps[:beam_size] + + self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] + logger.info("ctc prefix search success") + return self.hyps + + def get_one_best_hyps(self): + """Return the one best result + + Returns: + list: the one best result + """ + return [self.hyps[0][0]] + + def get_hyps(self): + return self.hyps + + def reset(self): + """Rest the search cache value + """ + self.cur_hyps = None + self.hyps = None diff --git a/paddlespeech/server/tests/__init__.py b/paddlespeech/server/tests/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/tests/asr/__init__.py b/paddlespeech/server/tests/asr/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
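Editor's note: because cur_hyps is kept between calls, the CTCPrefixBeamSearch class added above can be fed chunk-wise CTC posteriors and extends its prefixes incrementally. A small usage sketch follows; the config object, feature tensor and random log-probabilities are placeholders rather than real model output, and the import path assumes this patch is installed.

import paddle

from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch

class DummyDecodeConfig:                 # stand-in for config.decode
    beam_size = 10

searcher = CTCPrefixBeamSearch(DummyDecodeConfig())
vocab_size = 50
for _ in range(3):                       # three successive decoded chunks
    xs = paddle.zeros([1, 16, 80])       # placeholder chunk feature, batch size 1
    ctc_probs = paddle.nn.functional.log_softmax(
        paddle.randn([16, vocab_size]), axis=-1)   # (frames, vocab) log-probs
    searcher.search(xs, ctc_probs, xs.place)
    print(searcher.get_one_best_hyps())  # running one-best token prefix so far
searcher.reset()                         # clear prefixes before the next utterance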
diff --git a/paddlespeech/server/tests/asr/offline/__init__.py b/paddlespeech/server/tests/asr/offline/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/offline/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/tests/asr/online/__init__.py b/paddlespeech/server/tests/asr/online/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/online/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 4d1013f4..87b43d2c 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -34,17 +34,17 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] # init buffer + # each websocekt connection has its own chunk buffer chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer = ChunkBuffer( - window_n=7, - shift_n=4, - window_ms=20, - shift_ms=10, - sample_rate=chunk_buffer_conf['sample_rate'], - sample_width=chunk_buffer_conf['sample_width']) + window_n=chunk_buffer_conf.window_n, + shift_n=chunk_buffer_conf.shift_n, + window_ms=chunk_buffer_conf.window_ms, + shift_ms=chunk_buffer_conf.shift_ms, + sample_rate=chunk_buffer_conf.sample_rate, + sample_width=chunk_buffer_conf.sample_width) + # init vad - # print(asr_engine.config) - # print(type(asr_engine.config)) vad_conf = asr_engine.config.get('vad_conf', None) if vad_conf: vad = VADAudio( @@ -72,7 +72,7 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] # reset single engine for an new connection - # asr_engine.reset() + asr_engine.reset() resp = {"status": "ok", "signal": "finished"} await websocket.send_json(resp) break @@ -85,21 +85,16 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] asr_results = "" - # frames = chunk_buffer.frame_generator(message) - # for frame in frames: - # # get the pcm data from the bytes - # samples = np.frombuffer(frame.bytes, dtype=np.int16) - # sample_rate = asr_engine.config.sample_rate - # x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - # sample_rate) - # 
asr_engine.run(x_chunk, x_chunk_lens) - # asr_results = asr_engine.postprocess() - samples = np.frombuffer(message, dtype=np.int16) - sample_rate = asr_engine.config.sample_rate - x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - sample_rate) - asr_engine.run(x_chunk, x_chunk_lens) - # asr_results = asr_engine.postprocess() + frames = chunk_buffer.frame_generator(message) + for frame in frames: + # get the pcm data from the bytes + samples = np.frombuffer(frame.bytes, dtype=np.int16) + sample_rate = asr_engine.config.sample_rate + x_chunk, x_chunk_lens = asr_engine.preprocess(samples, + sample_rate) + asr_engine.run(x_chunk, x_chunk_lens) + asr_results = asr_engine.postprocess() + asr_results = asr_engine.postprocess() resp = {'asr_results': asr_results} From d2640c14064058c5283830fd2046d1788e800046 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 12:58:40 +0800 Subject: [PATCH 08/31] add mult sesssion process, test=doc --- .../server/engine/asr/online/asr_engine.py | 190 +++++++++++++++++- 1 file changed, 189 insertions(+), 1 deletion(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index e292f9cf..3546e598 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -78,6 +78,194 @@ pretrained_models = { }, } +# ASR server connection process class + +class PaddleASRConnectionHanddler: + def __init__(self, asr_engine): + super().__init__() + self.config = asr_engine.config + self.model_config = asr_engine.executor.config + self.asr_engine = asr_engine + + self.init() + self.reset() + + def init(self): + self.model_type = self.asr_engine.executor.model_type + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + pass + elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: + self.sample_rate = self.asr_engine.executor.sample_rate + + # acoustic model + self.model = self.asr_engine.executor.model + + # tokens to text + self.text_feature = self.asr_engine.executor.text_feature + + # ctc decoding + self.ctc_decode_config = self.asr_engine.executor.config.decode + self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) + + # extract fbank + self.preprocess_conf = self.model_config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + self.win_length = self.preprocess_conf.process[0]['win_length'] + self.n_shift = self.preprocess_conf.process[0]['n_shift'] + + def extract_feat(self, samples): + if "deepspeech2online" in self.model_type: + pass + elif "conformer2online" in self.model_type: + logger.info("Online ASR extract the feat") + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + logger.info(f"This package receive {samples.shape[0]} pcm data") + self.num_samples += samples.shape[0] + + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The connection remain the audio samples: {self.remained_wav.shape}" + ) + if len(self.remained_wav) < self.win_length: + return 0 + + # fbank + x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) + x_chunk = paddle.to_tensor( + x_chunk, 
dtype="float32").unsqueeze(axis=0) + if self.cached_feat is None: + self.cached_feat = x_chunk + else: + self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1) + + num_frames = x_chunk.shape[1] + self.num_frames += num_frames + self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + + logger.info( + f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" + ) + # logger.info(f"accumulate samples: {self.num_samples}") + + def reset(self): + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_outs_ = None + self.cached_feat = None + self.remained_wav = None + self.offset = 0 + self.num_samples = 0 + + self.num_frames = 0 + self.global_frame_offset = 0 + self.result = [] + + def decode(self, is_finished=False): + if "deepspeech2online" in self.model_type: + pass + elif "conformer" in self.model_type or "transformer" in self.model_type: + try: + logger.info( + f"we will use the transformer like model : {self.model_type}" + ) + self.advance_decoding(is_finished) + # self.update_result() + + # return self.result_transcripts[0] + except Exception as e: + logger.exception(e) + else: + raise Exception("invalid model name") + + def advance_decoding(self, is_finished=False): + logger.info("start to decode with advanced_decoding method") + cfg = self.ctc_decode_config + decoding_chunk_size = cfg.decoding_chunk_size + num_decoding_left_chunks = cfg.num_decoding_left_chunks + + assert decoding_chunk_size > 0 + subsampling = self.model.encoder.embed.subsampling_rate + context = self.model.encoder.embed.right_context + 1 + stride = subsampling * decoding_chunk_size + + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = self.cached_feat.shape[1] + logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames") + + # the cached feat must be larger decoding_window + if num_frames < decoding_window and not is_finished: + return None, None + + # logger.info("start to do model forward") + # required_cache_size = decoding_chunk_size * num_decoding_left_chunks + # outputs = [] + + # # num_frames - context + 1 ensure that current frame can get context window + # if is_finished: + # # if get the finished chunk, we need process the last context + # left_frames = context + # else: + # # we only process decoding_window frames for one chunk + # left_frames = decoding_window + + # logger.info(f"") + # end = None + # for cur in range(0, num_frames - left_frames + 1, stride): + # end = min(cur + decoding_window, num_frames) + # print(f"cur: {cur}, end: {end}") + # chunk_xs = self.cached_feat[:, cur:end, :] + # (y, self.subsampling_cache, self.elayers_output_cache, + # self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + # chunk_xs, self.offset, required_cache_size, + # self.subsampling_cache, self.elayers_output_cache, + # self.conformer_cnn_cache) + # outputs.append(y) + # update the offset + # self.offset += y.shape[1] + # self.cached_feat = self.cached_feat[end:] + # ys = paddle.cat(outputs, 1) + # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) + # masks = masks.unsqueeze(1) + + # # get the ctc probs + # ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) + # ctc_probs = ctc_probs.squeeze(0) + # # self.searcher.search(xs, ctc_probs, xs.place) + + 
# self.searcher.search(None, ctc_probs, self.cached_feat.place) + + # self.hyps = self.searcher.get_one_best_hyps() + + # ys for rescoring + # return ys, masks + + def update_result(self): + logger.info("update the final result") + hyps = self.hyps + self.result_transcripts = [ + self.text_feature.defeaturize(hyp) for hyp in hyps + ] + self.result_tokenids = [hyp for hyp in hyps] + + def rescoring(self): + pass + + + class ASRServerExecutor(ASRExecutor): def __init__(self): @@ -492,7 +680,7 @@ class ASRServerExecutor(ASRExecutor): if sample_rate != self.sample_rate: logger.info(f"audio sample rate {sample_rate} is not match," "the model sample_rate is {self.sample_rate}") - logger.info("ASR Engine use the {self.model_type} to process") + logger.info(f"ASR Engine use the {self.model_type} to process") logger.info("Create the preprocess instance") preprocess_conf = self.config.preprocess_config preprocess_args = {"train": False} From 10e825d9b2f619b0c8525c3c24491a657ccc9269 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 15:05:48 +0800 Subject: [PATCH 09/31] check chunk window process, test=doc --- .../server/engine/asr/online/asr_engine.py | 49 ++++++++++++------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 3546e598..1f6060e9 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -145,6 +145,8 @@ class PaddleASRConnectionHanddler: if self.cached_feat is None: self.cached_feat = x_chunk else: + assert(len(x_chunk.shape) == 3) + assert(len(self.cached_feat.shape) == 3) self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1) num_frames = x_chunk.shape[1] @@ -170,6 +172,7 @@ class PaddleASRConnectionHanddler: self.num_samples = 0 self.num_frames = 0 + self.chunk_num = 0 self.global_frame_offset = 0 self.result = [] @@ -210,23 +213,24 @@ class PaddleASRConnectionHanddler: if num_frames < decoding_window and not is_finished: return None, None - # logger.info("start to do model forward") - # required_cache_size = decoding_chunk_size * num_decoding_left_chunks - # outputs = [] - - # # num_frames - context + 1 ensure that current frame can get context window - # if is_finished: - # # if get the finished chunk, we need process the last context - # left_frames = context - # else: - # # we only process decoding_window frames for one chunk - # left_frames = decoding_window + logger.info("start to do model forward") + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + outputs = [] + + # num_frames - context + 1 ensure that current frame can get context window + if is_finished: + # if get the finished chunk, we need process the last context + left_frames = context + else: + # we only process decoding_window frames for one chunk + left_frames = decoding_window # logger.info(f"") - # end = None - # for cur in range(0, num_frames - left_frames + 1, stride): - # end = min(cur + decoding_window, num_frames) - # print(f"cur: {cur}, end: {end}") + end = None + for cur in range(0, num_frames - left_frames + 1, stride): + end = min(cur + decoding_window, num_frames) + print(f"cur chunk: {self.chunk_num}, cur: {cur}, end: {end}") + self.chunk_num += 1 # chunk_xs = self.cached_feat[:, cur:end, :] # (y, self.subsampling_cache, self.elayers_output_cache, # self.conformer_cnn_cache) = self.model.encoder.forward_chunk( @@ -236,7 +240,14 @@ class 
PaddleASRConnectionHanddler: # outputs.append(y) # update the offset # self.offset += y.shape[1] - # self.cached_feat = self.cached_feat[end:] + + # remove the processed feat + if end == num_frames: + self.cached_feat = None + else: + assert self.cached_feat.shape[0] == 1 + self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0) + assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" # ys = paddle.cat(outputs, 1) # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) # masks = masks.unsqueeze(1) @@ -309,9 +320,9 @@ class ASRServerExecutor(ASRExecutor): logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - # self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" - self.cfg_path = os.path.join(res_path, - pretrained_models[tag]['cfg_path']) + self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" + # self.cfg_path = os.path.join(res_path, + # pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) From 68731c61f40d2dc5eb154c5b9cd3faa8f0efd672 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 15:17:29 +0800 Subject: [PATCH 10/31] add multi session result, test=doc --- .../server/engine/asr/online/asr_engine.py | 50 ++++++++++--------- .../server/engine/asr/online/ctc_search.py | 2 +- paddlespeech/server/utils/buffer.py | 2 +- paddlespeech/server/ws/asr_socket.py | 36 ++++++++----- 4 files changed, 51 insertions(+), 39 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 1f6060e9..c13b2f6d 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -185,9 +185,9 @@ class PaddleASRConnectionHanddler: f"we will use the transformer like model : {self.model_type}" ) self.advance_decoding(is_finished) - # self.update_result() + self.update_result() - # return self.result_transcripts[0] + return self.result_transcripts[0] except Exception as e: logger.exception(e) else: @@ -225,22 +225,36 @@ class PaddleASRConnectionHanddler: # we only process decoding_window frames for one chunk left_frames = decoding_window - # logger.info(f"") + # record the end for removing the processed feat end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) - print(f"cur chunk: {self.chunk_num}, cur: {cur}, end: {end}") + self.chunk_num += 1 - # chunk_xs = self.cached_feat[:, cur:end, :] - # (y, self.subsampling_cache, self.elayers_output_cache, - # self.conformer_cnn_cache) = self.model.encoder.forward_chunk( - # chunk_xs, self.offset, required_cache_size, - # self.subsampling_cache, self.elayers_output_cache, - # self.conformer_cnn_cache) - # outputs.append(y) + chunk_xs = self.cached_feat[:, cur:end, :] + (y, self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + chunk_xs, self.offset, required_cache_size, + self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) + outputs.append(y) + # update the offset - # self.offset += y.shape[1] + self.offset += y.shape[1] + ys = paddle.cat(outputs, 1) + masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) + masks = masks.unsqueeze(1) + + # get the ctc probs + ctc_probs 
= self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + # self.searcher.search(xs, ctc_probs, xs.place) + + self.searcher.search(None, ctc_probs, self.cached_feat.place) + + self.hyps = self.searcher.get_one_best_hyps() + # remove the processed feat if end == num_frames: self.cached_feat = None @@ -248,19 +262,7 @@ class PaddleASRConnectionHanddler: assert self.cached_feat.shape[0] == 1 self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0) assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - # ys = paddle.cat(outputs, 1) - # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) - # masks = masks.unsqueeze(1) - - # # get the ctc probs - # ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) - # ctc_probs = ctc_probs.squeeze(0) - # # self.searcher.search(xs, ctc_probs, xs.place) - - # self.searcher.search(None, ctc_probs, self.cached_feat.place) - # self.hyps = self.searcher.get_one_best_hyps() - # ys for rescoring # return ys, masks diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index a91b8a21..bf4c4b30 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -46,7 +46,7 @@ class CTCPrefixBeamSearch: logger.info("start to ctc prefix search") # device = xs.place - batch_size = xs.shape[0] + batch_size = 1 beam_size = self.config.beam_size maxlen = ctc_probs.shape[0] diff --git a/paddlespeech/server/utils/buffer.py b/paddlespeech/server/utils/buffer.py index 12b1f0e5..d4e6cd49 100644 --- a/paddlespeech/server/utils/buffer.py +++ b/paddlespeech/server/utils/buffer.py @@ -63,12 +63,12 @@ class ChunkBuffer(object): the sample rate. Yields Frames of the requested duration. 
""" + audio = self.remained_audio + audio self.remained_audio = b'' offset = 0 timestamp = 0.0 - while offset + self.window_bytes <= len(audio): yield Frame(audio[offset:offset + self.window_bytes], timestamp, self.window_sec) diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 87b43d2c..04807e5c 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -22,6 +22,7 @@ from starlette.websockets import WebSocketState as WebSocketState from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.utils.buffer import ChunkBuffer from paddlespeech.server.utils.vad import VADAudio +from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler router = APIRouter() @@ -33,6 +34,7 @@ async def websocket_endpoint(websocket: WebSocket): engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] + connection_handler = None # init buffer # each websocekt connection has its own chunk buffer chunk_buffer_conf = asr_engine.config.chunk_buffer_conf @@ -67,13 +69,17 @@ async def websocket_endpoint(websocket: WebSocket): if message['signal'] == 'start': resp = {"status": "ok", "signal": "server_ready"} # do something at begining here + # create the instance to process the audio + connection_handler = PaddleASRConnectionHanddler(asr_engine) await websocket.send_json(resp) elif message['signal'] == 'end': engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] # reset single engine for an new connection + asr_results = connection_handler.decode(is_finished=True) + connection_handler.reset() asr_engine.reset() - resp = {"status": "ok", "signal": "finished"} + resp = {"status": "ok", "signal": "finished", 'asr_results': asr_results} await websocket.send_json(resp) break else: @@ -81,23 +87,27 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_json(resp) elif "bytes" in message: message = message["bytes"] - engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] asr_results = "" - frames = chunk_buffer.frame_generator(message) - for frame in frames: - # get the pcm data from the bytes - samples = np.frombuffer(frame.bytes, dtype=np.int16) - sample_rate = asr_engine.config.sample_rate - x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - sample_rate) - asr_engine.run(x_chunk, x_chunk_lens) - asr_results = asr_engine.postprocess() + connection_handler.extract_feat(message) + asr_results = connection_handler.decode(is_finished=False) + # connection_handler. 
+ # frames = chunk_buffer.frame_generator(message) + # for frame in frames: + # # get the pcm data from the bytes + # samples = np.frombuffer(frame.bytes, dtype=np.int16) + # sample_rate = asr_engine.config.sample_rate + # x_chunk, x_chunk_lens = asr_engine.preprocess(samples, + # sample_rate) + # asr_engine.run(x_chunk, x_chunk_lens) + # asr_results = asr_engine.postprocess() - asr_results = asr_engine.postprocess() + # # connection accept the sample data frame by frame + + # asr_results = asr_engine.postprocess() resp = {'asr_results': asr_results} - + print("\n") await websocket.send_json(resp) except WebSocketDisconnect: pass From 9c0ceaacb6aafa1175b0df7372fb411e2fd772fe Mon Sep 17 00:00:00 2001 From: lym0302 Date: Mon, 18 Apr 2022 17:27:45 +0800 Subject: [PATCH 11/31] add streaming am infer, test=doc --- .../server/engine/tts/online/tts_engine.py | 517 ++++++++++++++++-- paddlespeech/server/utils/util.py | 4 + 2 files changed, 462 insertions(+), 59 deletions(-) diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/tts_engine.py index 25a8bc76..8e76225d 100644 --- a/paddlespeech/server/engine/tts/online/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/tts_engine.py @@ -12,24 +12,322 @@ # See the License for the specific language governing permissions and # limitations under the License. import base64 +import math +import os import time +from typing import Optional import numpy as np import paddle +import yaml +from yacs.config import CfgNode from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.cli.utils import download_and_decompress +from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import float2pcm +from paddlespeech.server.utils.util import denorm from paddlespeech.server.utils.util import get_chunks +from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.zh_frontend import Frontend +from paddlespeech.t2s.modules.normalizer import ZScore + +__all__ = ['TTSEngine'] + +# support online model +pretrained_models = { + # fastspeech2 + "fastspeech2_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', + 'md5': + '637d28a5e53aa60275612ba4393d5f22', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_76000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "fastspeech2_cnndecoder_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip', + 'md5': + '6eb28e22ace73e0ebe7845f86478f89f', + 'config': + 'cnndecoder.yaml', + 'ckpt': + 'snapshot_iter_153000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + + # mb_melgan + "mb_melgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'ee5f0604e20091f0d495b6ec4618b90d', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + + # hifigan + "hifigan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'dd40a3d88dfcf64513fba2f0f961ada6', + 'config': + 'default.yaml', + 
'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, +} + +model_alias = { + # acoustic model + "fastspeech2": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2", + "fastspeech2_inference": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", + + # voc + "mb_melgan": + "paddlespeech.t2s.models.melgan:MelGANGenerator", + "mb_melgan_inference": + "paddlespeech.t2s.models.melgan:MelGANInference", + "hifigan": + "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", + "hifigan_inference": + "paddlespeech.t2s.models.hifigan:HiFiGANInference", +} __all__ = ['TTSEngine'] class TTSServerExecutor(TTSExecutor): - def __init__(self): + def __init__(self, am_block, am_pad, voc_block, voc_pad): super().__init__() - pass + self.am_block = am_block + self.am_pad = am_pad + self.voc_block = voc_block + self.voc_pad = voc_pad + + def get_model_info(self, step, model_name, ckpt, stat): + """get model information + + Args: + step (string): am or voc + model_name (string): model type, support fastspeech2, higigan, mb_melgan + ckpt (string): ckpt file + stat (string): stat file, including mean and standard deviation + + Returns: + model, model_mu, model_std + """ + model_class = dynamic_import(model_name, model_alias) + + if step == "am": + odim = self.am_config.n_mels + model = model_class( + idim=self.vocab_size, odim=odim, **self.am_config["model"]) + model.set_state_dict(paddle.load(ckpt)["main_params"]) + + elif step == "voc": + model = model_class(**self.voc_config["generator_params"]) + model.set_state_dict(paddle.load(ckpt)["generator_params"]) + model.remove_weight_norm() + + else: + logger.error("Please set correct step, am or voc") + + model.eval() + model_mu, model_std = np.load(stat) + model_mu = paddle.to_tensor(model_mu) + model_std = paddle.to_tensor(model_std) + + return model, model_mu, model_std + + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. + """ + support_models = list(pretrained_models.keys()) + assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + return decompressed_path + + def _init_from_path( + self, + am: str='fastspeech2_csmsc', + am_config: Optional[os.PathLike]=None, + am_ckpt: Optional[os.PathLike]=None, + am_stat: Optional[os.PathLike]=None, + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None, + speaker_dict: Optional[os.PathLike]=None, + voc: str='mb_melgan_csmsc', + voc_config: Optional[os.PathLike]=None, + voc_ckpt: Optional[os.PathLike]=None, + voc_stat: Optional[os.PathLike]=None, + lang: str='zh', ): + """ + Init model and other resources from a specific path. 
+ """ + if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'): + logger.info('Models had been initialized.') + return + # am model info + am_tag = am + '-' + lang + if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: + am_res_path = self._get_pretrained_path(am_tag) + self.am_res_path = am_res_path + self.am_config = os.path.join(am_res_path, + pretrained_models[am_tag]['config']) + self.am_ckpt = os.path.join(am_res_path, + pretrained_models[am_tag]['ckpt']) + self.am_stat = os.path.join( + am_res_path, pretrained_models[am_tag]['speech_stats']) + # must have phones_dict in acoustic + self.phones_dict = os.path.join( + am_res_path, pretrained_models[am_tag]['phones_dict']) + print("self.phones_dict:", self.phones_dict) + logger.info(am_res_path) + logger.info(self.am_config) + logger.info(self.am_ckpt) + else: + self.am_config = os.path.abspath(am_config) + self.am_ckpt = os.path.abspath(am_ckpt) + self.am_stat = os.path.abspath(am_stat) + self.phones_dict = os.path.abspath(phones_dict) + self.am_res_path = os.path.dirname(os.path.abspath(self.am_config)) + print("self.phones_dict:", self.phones_dict) + + self.tones_dict = None + self.speaker_dict = None + + # voc model info + voc_tag = voc + '-' + lang + if voc_ckpt is None or voc_config is None or voc_stat is None: + voc_res_path = self._get_pretrained_path(voc_tag) + self.voc_res_path = voc_res_path + self.voc_config = os.path.join(voc_res_path, + pretrained_models[voc_tag]['config']) + self.voc_ckpt = os.path.join(voc_res_path, + pretrained_models[voc_tag]['ckpt']) + self.voc_stat = os.path.join( + voc_res_path, pretrained_models[voc_tag]['speech_stats']) + logger.info(voc_res_path) + logger.info(self.voc_config) + logger.info(self.voc_ckpt) + else: + self.voc_config = os.path.abspath(voc_config) + self.voc_ckpt = os.path.abspath(voc_ckpt) + self.voc_stat = os.path.abspath(voc_stat) + self.voc_res_path = os.path.dirname( + os.path.abspath(self.voc_config)) + + # Init body. 
+ with open(self.am_config) as f: + self.am_config = CfgNode(yaml.safe_load(f)) + with open(self.voc_config) as f: + self.voc_config = CfgNode(yaml.safe_load(f)) + + with open(self.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + self.vocab_size = len(phn_id) + print("vocab_size:", self.vocab_size) + + # frontend + if lang == 'zh': + self.frontend = Frontend( + phone_vocab_path=self.phones_dict, + tone_vocab_path=self.tones_dict) + + elif lang == 'en': + self.frontend = English(phone_vocab_path=self.phones_dict) + print("frontend done!") + + # am infer info + self.am_name = am[:am.rindex('_')] + if self.am_name == "fastspeech2_cnndecoder": + self.am_inference, self.am_mu, self.am_std = self.get_model_info( + "am", "fastspeech2", self.am_ckpt, self.am_stat) + else: + am, am_mu, am_std = self.get_model_info("am", self.am_name, + self.am_ckpt, self.am_stat) + am_normalizer = ZScore(am_mu, am_std) + am_inference_class = dynamic_import(self.am_name + '_inference', + model_alias) + self.am_inference = am_inference_class(am_normalizer, am) + self.am_inference.eval() + print("acoustic model done!") + + # voc infer info + self.voc_name = voc[:voc.rindex('_')] + voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name, + self.voc_ckpt, self.voc_stat) + voc_normalizer = ZScore(voc_mu, voc_std) + voc_inference_class = dynamic_import(self.voc_name + '_inference', + model_alias) + self.voc_inference = voc_inference_class(voc_normalizer, voc) + self.voc_inference.eval() + print("voc done!") + + def get_phone(self, sentence, lang, merge_sentences, get_tone_ids): + tone_ids = None + if lang == 'zh': + input_ids = self.frontend.get_input_ids( + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + elif lang == 'en': + input_ids = self.frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + + def depadding(self, data, chunk_num, chunk_id, block, pad, upsample): + """ + Streaming inference removes the result of pad inference + """ + front_pad = min(chunk_id * block, pad) + # first chunk + if chunk_id == 0: + data = data[:block * upsample] + # last chunk + elif chunk_id == chunk_num - 1: + data = data[front_pad * upsample:] + # middle chunk + else: + data = data[front_pad * upsample:(front_pad + block) * upsample] + + return data @paddle.no_grad() def infer( @@ -37,16 +335,19 @@ class TTSServerExecutor(TTSExecutor): text: str, lang: str='zh', am: str='fastspeech2_csmsc', - spk_id: int=0, - am_block: int=42, - am_pad: int=12, - voc_block: int=14, - voc_pad: int=14, ): + spk_id: int=0, ): """ Model inference and result stored in self.output. 
""" - am_name = am[:am.rindex('_')] - am_dataset = am[am.rindex('_') + 1:] + + am_block = self.am_block + am_pad = self.am_pad + am_upsample = 1 + voc_block = self.voc_block + voc_pad = self.voc_pad + voc_upsample = self.voc_config.n_shift + flag = 1 + get_tone_ids = False merge_sentences = False frontend_st = time.time() @@ -64,43 +365,99 @@ class TTSServerExecutor(TTSExecutor): phone_ids = input_ids["phone_ids"] else: print("lang should in {'zh', 'en'}!") - self.frontend_time = time.time() - frontend_st + frontend_et = time.time() + self.frontend_time = frontend_et - frontend_st for i in range(len(phone_ids)): - am_st = time.time() part_phone_ids = phone_ids[i] - # am - if am_name == 'speedyspeech': - part_tone_ids = tone_ids[i] - mel = self.am_inference(part_phone_ids, part_tone_ids) - # fastspeech2 + voc_chunk_id = 0 + + # fastspeech2_csmsc + if am == "fastspeech2_csmsc": + # am + mel = self.am_inference(part_phone_ids) + if flag == 1: + first_am_et = time.time() + self.first_am_infer = first_am_et - frontend_et + + # voc streaming + mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") + voc_chunk_num = len(mel_chunks) + voc_st = time.time() + for i, mel_chunk in enumerate(mel_chunks): + sub_wav = self.voc_inference(mel_chunk) + sub_wav = self.depadding(sub_wav, voc_chunk_num, i, + voc_block, voc_pad, voc_upsample) + if flag == 1: + first_voc_et = time.time() + self.first_voc_infer = first_voc_et - first_am_et + self.first_response_time = first_voc_et - frontend_st + flag = 0 + + yield sub_wav + + # fastspeech2_cnndecoder_csmsc + elif am == "fastspeech2_cnndecoder_csmsc": + # am + orig_hs, h_masks = self.am_inference.encoder_infer( + part_phone_ids) + + # streaming voc chunk info + mel_len = orig_hs.shape[1] + voc_chunk_num = math.ceil(mel_len / self.voc_block) + start = 0 + end = min(self.voc_block + self.voc_pad, mel_len) + + # streaming am + hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am") + am_chunk_num = len(hss) + for i, hs in enumerate(hss): + before_outs, _ = self.am_inference.decoder(hs) + after_outs = before_outs + self.am_inference.postnet( + before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + normalized_mel = after_outs[0] + sub_mel = denorm(normalized_mel, self.am_mu, self.am_std) + sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block, + am_pad, am_upsample) + + if i == 0: + mel_streaming = sub_mel + else: + mel_streaming = np.concatenate( + (mel_streaming, sub_mel), axis=0) + + # streaming voc + while (mel_streaming.shape[0] >= end and + voc_chunk_id < voc_chunk_num): + if flag == 1: + first_am_et = time.time() + self.first_am_infer = first_am_et - frontend_et + voc_chunk = mel_streaming[start:end, :] + voc_chunk = paddle.to_tensor(voc_chunk) + sub_wav = self.voc_inference(voc_chunk) + + sub_wav = self.depadding(sub_wav, voc_chunk_num, + voc_chunk_id, voc_block, + voc_pad, voc_upsample) + if flag == 1: + first_voc_et = time.time() + self.first_voc_infer = first_voc_et - first_am_et + self.first_response_time = first_voc_et - frontend_st + flag = 0 + + yield sub_wav + + voc_chunk_id += 1 + start = max(0, voc_chunk_id * voc_block - voc_pad) + end = min((voc_chunk_id + 1) * voc_block + voc_pad, + mel_len) + else: - # multi speaker - if am_dataset in {"aishell3", "vctk"}: - mel = self.am_inference( - part_phone_ids, spk_id=paddle.to_tensor(spk_id)) - else: - mel = self.am_inference(part_phone_ids) - am_et = time.time() - - # voc streaming - voc_upsample = self.voc_config.n_shift - mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") - chunk_num = 
len(mel_chunks) - voc_st = time.time() - for i, mel_chunk in enumerate(mel_chunks): - sub_wav = self.voc_inference(mel_chunk) - front_pad = min(i * voc_block, voc_pad) - - if i == 0: - sub_wav = sub_wav[:voc_block * voc_upsample] - elif i == chunk_num - 1: - sub_wav = sub_wav[front_pad * voc_upsample:] - else: - sub_wav = sub_wav[front_pad * voc_upsample:( - front_pad + voc_block) * voc_upsample] - - yield sub_wav + logger.error( + "Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts." + ) + + self.final_response_time = time.time() - frontend_st class TTSEngine(BaseEngine): @@ -116,11 +473,18 @@ class TTSEngine(BaseEngine): super(TTSEngine, self).__init__() def init(self, config: dict) -> bool: - self.executor = TTSServerExecutor() self.config = config - assert "fastspeech2_csmsc" in config.am and ( - config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc" + assert ( + config.am == "fastspeech2_csmsc" or + config.am == "fastspeech2_cnndecoder_csmsc" + ) and ( + config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc" ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' + + assert ( + config.voc_block > 0 and config.voc_pad > 0 + ), "Please set correct voc_block and voc_pad, they should be more than 0." + try: if self.config.device: self.device = self.config.device @@ -135,6 +499,9 @@ class TTSEngine(BaseEngine): (self.device)) return False + self.executor = TTSServerExecutor(config.am_block, config.am_pad, + config.voc_block, config.voc_pad) + try: self.executor._init_from_path( am=self.config.am, @@ -155,15 +522,42 @@ class TTSEngine(BaseEngine): (self.device)) return False - self.am_block = self.config.am_block - self.am_pad = self.config.am_pad - self.voc_block = self.config.voc_block - self.voc_pad = self.config.voc_pad - logger.info("Initialize TTS server engine successfully on device: %s." % (self.device)) + + # warm up + try: + self.warm_up() + except Exception as e: + logger.error("Failed to warm up on tts engine.") + return False + return True + def warm_up(self): + """warm up + """ + if self.config.lang == 'zh': + sentence = "您好,欢迎使用语音合成服务。" + if self.config.lang == 'en': + sentence = "Hello and welcome to the speech synthesis service." + logger.info( + "*******************************warm up ********************************" + ) + for i in range(3): + for wav in self.executor.infer( + text=sentence, + lang=self.config.lang, + am=self.config.am, + spk_id=0, ): + logger.info( + f"The first response time of the {i} warm up: {self.executor.first_response_time} s" + ) + break + logger.info( + "**********************************************************************" + ) + def preprocess(self, text_bese64: str=None, text_bytes: bytes=None): # Convert byte to text if text_bese64: @@ -195,18 +589,14 @@ class TTSEngine(BaseEngine): wav_base64: The base64 format of the synthesized audio. 
""" - lang = self.config.lang wav_list = [] for wav in self.executor.infer( text=sentence, - lang=lang, + lang=self.config.lang, am=self.config.am, - spk_id=spk_id, - am_block=self.am_block, - am_pad=self.am_pad, - voc_block=self.voc_block, - voc_pad=self.voc_pad): + spk_id=spk_id, ): + # wav type: float32, convert to pcm (base64) wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes @@ -216,5 +606,14 @@ class TTSEngine(BaseEngine): yield wav_base64 wav_all = np.concatenate(wav_list, axis=0) - logger.info("The durations of audio is: {} s".format( - len(wav_all) / self.executor.am_config.fs)) + duration = len(wav_all) / self.executor.am_config.fs + logger.info(f"sentence: {sentence}") + logger.info(f"The durations of audio is: {duration} s") + logger.info( + f"first response time: {self.executor.first_response_time} s") + logger.info( + f"final response time: {self.executor.final_response_time} s") + logger.info(f"RTF: {self.executor.final_response_time / duration}") + logger.info( + f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s," + ) diff --git a/paddlespeech/server/utils/util.py b/paddlespeech/server/utils/util.py index 0fe70849..72ee0060 100644 --- a/paddlespeech/server/utils/util.py +++ b/paddlespeech/server/utils/util.py @@ -52,6 +52,10 @@ def get_chunks(data, block_size, pad_size, step): Returns: list: chunks list """ + + if block_size == -1: + return [data] + if step == "am": data_len = data.shape[1] elif step == "voc": From 00a6236fe2c0affa3093551c1d88f0a92b2d0a42 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Mon, 18 Apr 2022 17:31:47 +0800 Subject: [PATCH 12/31] remove test code, test=doc --- paddlespeech/server/tests/tts/infer/run.sh | 62 -- .../server/tests/tts/infer/test_online_tts.py | 610 ------------------ 2 files changed, 672 deletions(-) delete mode 100644 paddlespeech/server/tests/tts/infer/run.sh delete mode 100644 paddlespeech/server/tests/tts/infer/test_online_tts.py diff --git a/paddlespeech/server/tests/tts/infer/run.sh b/paddlespeech/server/tests/tts/infer/run.sh deleted file mode 100644 index 3733c3fb..00000000 --- a/paddlespeech/server/tests/tts/infer/run.sh +++ /dev/null @@ -1,62 +0,0 @@ -model_path=~/.paddlespeech/models/ -am_model_dir=$model_path/fastspeech2_csmsc-zh/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0/ -voc_model_dir=$model_path/mb_melgan_csmsc-zh/mb_melgan_csmsc_ckpt_0.1.1/ -testdata=../../../../t2s/exps/csmsc_test.txt - -# get am file -for file in $(ls $am_model_dir) -do - if [[ $file == *"yaml"* ]]; then - am_config_file=$file - elif [[ $file == *"pdz"* ]]; then - am_ckpt_file=$file - elif [[ $file == *"stat"* ]]; then - am_stat_file=$file - elif [[ $file == *"phone"* ]]; then - phones_dict_file=$file - fi - -done - -# get voc file -for file in $(ls $voc_model_dir) -do - if [[ $file == *"yaml"* ]]; then - voc_config_file=$file - elif [[ $file == *"pdz"* ]]; then - voc_ckpt_file=$file - elif [[ $file == *"stat"* ]]; then - voc_stat_file=$file - fi - -done - - -# run test -# am can choose fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc, where fastspeech2_cnndecoder_csmsc supports streaming inference. -# voc can choose hifigan_csmsc and mb_melgan_csmsc, They can both support streaming inference. -# When am is fastspeech2_cnndecoder_csmsc and am_pad is set to 12, there is no diff between streaming and non-streaming inference results. 
-# When voc is mb_melgan_csmsc and voc_pad is set to 14, there is no diff between streaming and non-streaming inference results. -# When voc is hifigan_csmsc and voc_pad is set to 20, there is no diff between streaming and non-streaming inference results. - -python test_online_tts.py --am fastspeech2_cnndecoder_csmsc \ - --am_config $am_model_dir/$am_config_file \ - --am_ckpt $am_model_dir/$am_ckpt_file \ - --am_stat $am_model_dir/$am_stat_file \ - --phones_dict $am_model_dir/$phones_dict_file \ - --voc mb_melgan_csmsc \ - --voc_config $voc_model_dir/$voc_config_file \ - --voc_ckpt $voc_model_dir/$voc_ckpt_file \ - --voc_stat $voc_model_dir/$voc_stat_file \ - --lang zh \ - --device cpu \ - --text $testdata \ - --output_dir ./output \ - --log_file ./result.log \ - --am_streaming True \ - --am_pad 12 \ - --am_block 42 \ - --voc_streaming True \ - --voc_pad 14 \ - --voc_block 14 \ - diff --git a/paddlespeech/server/tests/tts/infer/test_online_tts.py b/paddlespeech/server/tests/tts/infer/test_online_tts.py deleted file mode 100644 index eb5fc80b..00000000 --- a/paddlespeech/server/tests/tts/infer/test_online_tts.py +++ /dev/null @@ -1,610 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import argparse -import logging -import math -import threading -import time -from pathlib import Path - -import numpy as np -import paddle -import soundfile as sf -import yaml -from yacs.config import CfgNode - -from paddlespeech.s2t.utils.dynamic_import import dynamic_import -from paddlespeech.t2s.exps.syn_utils import get_am_inference -from paddlespeech.t2s.exps.syn_utils import get_frontend -from paddlespeech.t2s.exps.syn_utils import get_sentences -from paddlespeech.t2s.exps.syn_utils import get_voc_inference -from paddlespeech.t2s.exps.syn_utils import model_alias -from paddlespeech.t2s.utils import str2bool - -mel_streaming = None -wav_streaming = None -streaming_first_time = 0.0 -streaming_voc_st = 0.0 -sample_rate = 0 - - -def denorm(data, mean, std): - return data * std + mean - - -def get_chunks(data, block_size, pad_size, step): - if step == "am": - data_len = data.shape[1] - elif step == "voc": - data_len = data.shape[0] - else: - print("Please set correct type to get chunks, am or voc") - - chunks = [] - n = math.ceil(data_len / block_size) - for i in range(n): - start = max(0, i * block_size - pad_size) - end = min((i + 1) * block_size + pad_size, data_len) - if step == "am": - chunks.append(data[:, start:end, :]) - elif step == "voc": - chunks.append(data[start:end, :]) - else: - print("Please set correct type to get chunks, am or voc") - return chunks - - -def get_streaming_am_inference(args, am_config): - with open(args.phones_dict, "r") as f: - phn_id = [line.strip().split() for line in f.readlines()] - vocab_size = len(phn_id) - print("vocab_size:", vocab_size) - - am_name = "fastspeech2" - odim = am_config.n_mels - - am_class = dynamic_import(am_name, model_alias) - am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) - am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) - am.eval() - am_mu, am_std = np.load(args.am_stat) - am_mu = paddle.to_tensor(am_mu) - am_std = paddle.to_tensor(am_std) - - return am, am_mu, am_std - - -def init(args): - global sample_rate - # get config - with open(args.am_config) as f: - am_config = CfgNode(yaml.safe_load(f)) - with open(args.voc_config) as f: - voc_config = CfgNode(yaml.safe_load(f)) - - sample_rate = am_config.fs - - # frontend - frontend = get_frontend(args) - - # acoustic model - if args.am == 'fastspeech2_cnndecoder_csmsc': - am, am_mu, am_std = get_streaming_am_inference(args, am_config) - am_infer_info = [am, am_mu, am_std, am_config] - else: - am_inference, am_name, am_dataset = get_am_inference(args, am_config) - am_infer_info = [am_inference, am_name, am_dataset, am_config] - - # vocoder - voc_inference = get_voc_inference(args, voc_config) - voc_infer_info = [voc_inference, voc_config] - - return frontend, am_infer_info, voc_infer_info - - -def get_phone(args, frontend, sentence, merge_sentences, get_tone_ids): - am_name = args.am[:args.am.rindex('_')] - tone_ids = None - - if args.lang == 'zh': - input_ids = frontend.get_input_ids( - sentence, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) - phone_ids = input_ids["phone_ids"] - if get_tone_ids: - tone_ids = input_ids["tone_ids"] - elif args.lang == 'en': - input_ids = frontend.get_input_ids( - sentence, merge_sentences=merge_sentences) - phone_ids = input_ids["phone_ids"] - else: - print("lang should in {'zh', 'en'}!") - - return phone_ids, tone_ids - - -@paddle.no_grad() -# 生成完整的mel -def gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids): - # 如果是支持流式的AM模型 - if args.am == 'fastspeech2_cnndecoder_csmsc': - am, am_mu, am_std, 
am_config = am_infer_info - orig_hs, h_masks = am.encoder_infer(part_phone_ids) - if args.am_streaming: - am_pad = args.am_pad - am_block = args.am_block - hss = get_chunks(orig_hs, am_block, am_pad, "am") - chunk_num = len(hss) - mel_list = [] - for i, hs in enumerate(hss): - before_outs, _ = am.decoder(hs) - after_outs = before_outs + am.postnet( - before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) - normalized_mel = after_outs[0] - sub_mel = denorm(normalized_mel, am_mu, am_std) - # clip output part of pad - if i == 0: - sub_mel = sub_mel[:-am_pad] - elif i == chunk_num - 1: - # 最后一块的右侧一定没有 pad 够 - sub_mel = sub_mel[am_pad:] - else: - # 倒数几块的右侧也可能没有 pad 够 - sub_mel = sub_mel[am_pad:(am_block + am_pad) - - sub_mel.shape[0]] - mel_list.append(sub_mel) - mel = paddle.concat(mel_list, axis=0) - - else: - orig_hs, h_masks = am.encoder_infer(part_phone_ids) - before_outs, _ = am.decoder(orig_hs) - after_outs = before_outs + am.postnet( - before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) - normalized_mel = after_outs[0] - mel = denorm(normalized_mel, am_mu, am_std) - - else: - am_inference, am_name, am_dataset, am_config = am_infer_info - mel = am_inference(part_phone_ids) - - return mel - - -@paddle.no_grad() -def streaming_voc_infer(args, voc_infer_info, mel_len): - global mel_streaming - global streaming_first_time - global wav_streaming - voc_inference, voc_config = voc_infer_info - block = args.voc_block - pad = args.voc_pad - upsample = voc_config.n_shift - wav_list = [] - flag = 1 - - valid_start = 0 - valid_end = min(valid_start + block, mel_len) - actual_start = 0 - actual_end = min(valid_end + pad, mel_len) - mel_chunk = mel_streaming[actual_start:actual_end, :] - - while valid_end <= mel_len: - sub_wav = voc_inference(mel_chunk) - if flag == 1: - streaming_first_time = time.time() - flag = 0 - - # get valid wav - start = valid_start - actual_start - if valid_end == mel_len: - sub_wav = sub_wav[start * upsample:] - wav_list.append(sub_wav) - break - else: - end = start + block - sub_wav = sub_wav[start * upsample:end * upsample] - wav_list.append(sub_wav) - - # generate new mel chunk - valid_start = valid_end - valid_end = min(valid_start + block, mel_len) - if valid_start - pad < 0: - actual_start = 0 - else: - actual_start = valid_start - pad - actual_end = min(valid_end + pad, mel_len) - mel_chunk = mel_streaming[actual_start:actual_end, :] - - wav = paddle.concat(wav_list, axis=0) - wav_streaming = wav - - -@paddle.no_grad() -# 非流式AM / 流式AM + 非流式Voc -def am_nonstreaming_voc(args, am_infer_info, voc_infer_info, part_phone_ids, - part_tone_ids): - mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids) - am_infer_time = time.time() - voc_inference, voc_config = voc_infer_info - wav = voc_inference(mel) - first_response_time = time.time() - final_response_time = first_response_time - voc_infer_time = first_response_time - - return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav - - -@paddle.no_grad() -# 非流式AM + 流式Voc -def nonstreaming_am_streaming_voc(args, am_infer_info, voc_infer_info, - part_phone_ids, part_tone_ids): - global mel_streaming - global streaming_first_time - global wav_streaming - - mel = gen_mel(args, am_infer_info, part_phone_ids, part_tone_ids) - am_infer_time = time.time() - - # voc streaming - mel_streaming = mel - mel_len = mel.shape[0] - streaming_voc_infer(args, voc_infer_info, mel_len) - first_response_time = streaming_first_time - wav = wav_streaming - final_response_time = time.time() - voc_infer_time = 
final_response_time - - return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav - - -@paddle.no_grad() -# 流式AM + 流式 Voc -def streaming_am_streaming_voc(args, am_infer_info, voc_infer_info, - part_phone_ids, part_tone_ids): - global mel_streaming - global streaming_first_time - global wav_streaming - global streaming_voc_st - mel_streaming = None - #用来表示开启流式voc的线程 - flag = 1 - - am, am_mu, am_std, am_config = am_infer_info - orig_hs, h_masks = am.encoder_infer(part_phone_ids) - mel_len = orig_hs.shape[1] - am_block = args.am_block - am_pad = args.am_pad - hss = get_chunks(orig_hs, am_block, am_pad, "am") - chunk_num = len(hss) - - for i, hs in enumerate(hss): - before_outs, _ = am.decoder(hs) - after_outs = before_outs + am.postnet( - before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) - normalized_mel = after_outs[0] - sub_mel = denorm(normalized_mel, am_mu, am_std) - # clip output part of pad - if i == 0: - sub_mel = sub_mel[:-am_pad] - mel_streaming = sub_mel - elif i == chunk_num - 1: - # 最后一块的右侧一定没有 pad 够 - sub_mel = sub_mel[am_pad:] - mel_streaming = paddle.concat([mel_streaming, sub_mel]) - am_infer_time = time.time() - else: - # 倒数几块的右侧也可能没有 pad 够 - sub_mel = sub_mel[am_pad:(am_block + am_pad) - sub_mel.shape[0]] - mel_streaming = paddle.concat([mel_streaming, sub_mel]) - - if flag and mel_streaming.shape[0] > args.voc_block + args.voc_pad: - t = threading.Thread( - target=streaming_voc_infer, - args=(args, voc_infer_info, mel_len, )) - t.start() - streaming_voc_st = time.time() - flag = 0 - - t.join() - final_response_time = time.time() - voc_infer_time = final_response_time - first_response_time = streaming_first_time - wav = wav_streaming - - return am_infer_time, voc_infer_time, first_response_time, final_response_time, wav - - -def warm_up(args, logger, frontend, am_infer_info, voc_infer_info): - global sample_rate - logger.info( - "Before the formal test, we test a few texts to make the inference speed more stable." - ) - if args.lang == 'zh': - sentence = "您好,欢迎使用语音合成服务。" - if args.lang == 'en': - sentence = "Hello and welcome to the speech synthesis service." 
- - if args.voc_streaming: - if args.am_streaming: - infer_func = streaming_am_streaming_voc - else: - infer_func = nonstreaming_am_streaming_voc - else: - infer_func = am_nonstreaming_voc - - merge_sentences = True - get_tone_ids = False - for i in range(5): # 推理5次 - st = time.time() - phone_ids, tone_ids = get_phone(args, frontend, sentence, - merge_sentences, get_tone_ids) - part_phone_ids = phone_ids[0] - if tone_ids: - part_tone_ids = tone_ids[0] - else: - part_tone_ids = None - - am_infer_time, voc_infer_time, first_response_time, final_response_time, wav = infer_func( - args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids) - wav = wav.numpy() - duration = wav.size / sample_rate - logger.info( - f"sentence: {sentence}; duration: {duration} s; first response time: {first_response_time - st} s; final response time: {final_response_time - st} s" - ) - - -def evaluate(args, logger, frontend, am_infer_info, voc_infer_info): - global sample_rate - sentences = get_sentences(args) - - output_dir = Path(args.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - get_tone_ids = False - merge_sentences = True - - # choose infer function - if args.voc_streaming: - if args.am_streaming: - infer_func = streaming_am_streaming_voc - else: - infer_func = nonstreaming_am_streaming_voc - else: - infer_func = am_nonstreaming_voc - - final_up_duration = 0.0 - sentence_count = 0 - front_time_list = [] - am_time_list = [] - voc_time_list = [] - first_response_list = [] - final_response_list = [] - sentence_length_list = [] - duration_list = [] - - for utt_id, sentence in sentences: - # front - front_st = time.time() - phone_ids, tone_ids = get_phone(args, frontend, sentence, - merge_sentences, get_tone_ids) - part_phone_ids = phone_ids[0] - if tone_ids: - part_tone_ids = tone_ids[0] - else: - part_tone_ids = None - front_et = time.time() - front_time = front_et - front_st - - am_st = time.time() - am_infer_time, voc_infer_time, first_response_time, final_response_time, wav = infer_func( - args, am_infer_info, voc_infer_info, part_phone_ids, part_tone_ids) - am_time = am_infer_time - am_st - if args.voc_streaming and args.am_streaming: - voc_time = voc_infer_time - streaming_voc_st - else: - voc_time = voc_infer_time - am_infer_time - - first_response = first_response_time - front_st - final_response = final_response_time - front_st - - wav = wav.numpy() - duration = wav.size / sample_rate - sf.write( - str(output_dir / (utt_id + ".wav")), wav, samplerate=sample_rate) - print(f"{utt_id} done!") - - sentence_count += 1 - front_time_list.append(front_time) - am_time_list.append(am_time) - voc_time_list.append(voc_time) - first_response_list.append(first_response) - final_response_list.append(final_response) - sentence_length_list.append(len(sentence)) - duration_list.append(duration) - - logger.info( - f"uttid: {utt_id}; sentence: '{sentence}'; front time: {front_time} s; am time: {am_time} s; voc time: {voc_time} s; \ - first response time: {first_response} s; final response time: {final_response} s; audio duration: {duration} s;" - ) - - if final_response > duration: - final_up_duration += 1 - - all_time_sum = sum(final_response_list) - front_rate = sum(front_time_list) / all_time_sum - am_rate = sum(am_time_list) / all_time_sum - voc_rate = sum(voc_time_list) / all_time_sum - rtf = all_time_sum / sum(duration_list) - - logger.info( - f"The length of test text information, test num: {sentence_count}; text num: {sum(sentence_length_list)}; min: {min(sentence_length_list)}; max: 
{max(sentence_length_list)}; avg: {sum(sentence_length_list)/len(sentence_length_list)}" - ) - logger.info( - f"duration information, min: {min(duration_list)}; max: {max(duration_list)}; avg: {sum(duration_list) / len(duration_list)}; sum: {sum(duration_list)}" - ) - logger.info( - f"Front time information: min: {min(front_time_list)} s; max: {max(front_time_list)} s; avg: {sum(front_time_list)/len(front_time_list)} s; ratio: {front_rate * 100}%" - ) - logger.info( - f"AM time information: min: {min(am_time_list)} s; max: {max(am_time_list)} s; avg: {sum(am_time_list)/len(am_time_list)} s; ratio: {am_rate * 100}%" - ) - logger.info( - f"Vocoder time information: min: {min(voc_time_list)} s, max: {max(voc_time_list)} s; avg: {sum(voc_time_list)/len(voc_time_list)} s; ratio: {voc_rate * 100}%" - ) - logger.info( - f"first response time information: min: {min(first_response_list)} s; max: {max(first_response_list)} s; avg: {sum(first_response_list)/len(first_response_list)} s" - ) - logger.info( - f"final response time information: min: {min(final_response_list)} s; max: {max(final_response_list)} s; avg: {sum(final_response_list)/len(final_response_list)} s" - ) - logger.info(f"RTF is: {rtf}") - logger.info( - f"The number of final_response is greater than duration is {final_up_duration}, ratio: {final_up_duration / sentence_count}%" - ) - - -def parse_args(): - # parse args and config and redirect to train_sp - parser = argparse.ArgumentParser( - description="Synthesize with acoustic model & vocoder") - # acoustic model - parser.add_argument( - '--am', - type=str, - default='fastspeech2_csmsc', - choices=['fastspeech2_csmsc', 'fastspeech2_cnndecoder_csmsc'], - help='Choose acoustic model type of tts task. where fastspeech2_cnndecoder_csmsc supports streaming inference' - ) - - parser.add_argument( - '--am_config', - type=str, - default=None, - help='Config of acoustic model. Use deault config when it is None.') - parser.add_argument( - '--am_ckpt', - type=str, - default=None, - help='Checkpoint file of acoustic model.') - parser.add_argument( - "--am_stat", - type=str, - default=None, - help="mean and standard deviation used to normalize spectrogram when training acoustic model." - ) - parser.add_argument( - "--phones_dict", type=str, default=None, help="phone vocabulary file.") - parser.add_argument( - "--tones_dict", type=str, default=None, help="tone vocabulary file.") - # vocoder - parser.add_argument( - '--voc', - type=str, - default='mb_melgan_csmsc', - choices=['mb_melgan_csmsc', 'hifigan_csmsc'], - help='Choose vocoder type of tts task.') - parser.add_argument( - '--voc_config', - type=str, - default=None, - help='Config of voc. Use deault config when it is None.') - parser.add_argument( - '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') - parser.add_argument( - "--voc_stat", - type=str, - default=None, - help="mean and standard deviation used to normalize spectrogram when training voc." - ) - # other - parser.add_argument( - '--lang', - type=str, - default='zh', - choices=['zh', 'en'], - help='Choose model language. 
zh or en') - - parser.add_argument( - "--device", type=str, default='cpu', help="set cpu or gpu:id") - - parser.add_argument( - "--text", - type=str, - default="./csmsc_test.txt", - help="text to synthesize, a 'utt_id sentence' pair per line.") - parser.add_argument("--output_dir", type=str, help="output dir.") - parser.add_argument( - "--log_file", type=str, default="result.log", help="log file.") - - parser.add_argument( - "--am_streaming", - type=str2bool, - default=False, - help="whether use streaming acoustic model") - - parser.add_argument("--am_pad", type=int, default=12, help="am pad size.") - - parser.add_argument( - "--am_block", type=int, default=42, help="am block size.") - - parser.add_argument( - "--voc_streaming", - type=str2bool, - default=False, - help="whether use streaming vocoder model") - - parser.add_argument("--voc_pad", type=int, default=14, help="voc pad size.") - - parser.add_argument( - "--voc_block", type=int, default=14, help="voc block size.") - - args = parser.parse_args() - return args - - -def main(): - args = parse_args() - paddle.set_device(args.device) - if args.am_streaming: - assert (args.am == 'fastspeech2_cnndecoder_csmsc') - - logger = logging.getLogger() - fhandler = logging.FileHandler(filename=args.log_file, mode='w') - formatter = logging.Formatter( - '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' - ) - fhandler.setFormatter(formatter) - logger.addHandler(fhandler) - logger.setLevel(logging.DEBUG) - - # set basic information - logger.info( - f"AM: {args.am}; Vocoder: {args.voc}; device: {args.device}; am streaming: {args.am_streaming}; voc streaming: {args.voc_streaming}" - ) - logger.info( - f"am pad size: {args.am_pad}; am block size: {args.am_block}; voc pad size: {args.voc_pad}; voc block size: {args.voc_block};" - ) - - # get information about model - frontend, am_infer_info, voc_infer_info = init(args) - logger.info( - "************************ warm up *********************************") - warm_up(args, logger, frontend, am_infer_info, voc_infer_info) - logger.info( - "************************ normal test *******************************") - evaluate(args, logger, frontend, am_infer_info, voc_infer_info) - - -if __name__ == "__main__": - main() From 05a8a4b5fccec0fe549132717f24d25c3240b04f Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 17:11:49 +0800 Subject: [PATCH 13/31] add connection stability, test=doc --- .../server/engine/asr/online/asr_engine.py | 109 ++++++++++++++++-- .../server/engine/asr/online/ctc_search.py | 10 ++ paddlespeech/server/ws/asr_socket.py | 37 +++--- 3 files changed, 121 insertions(+), 35 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index c13b2f6d..696d223a 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -83,8 +83,10 @@ pretrained_models = { class PaddleASRConnectionHanddler: def __init__(self, asr_engine): super().__init__() + logger.info("create an paddle asr connection handler to process the websocket connection") self.config = asr_engine.config self.model_config = asr_engine.executor.config + self.model = asr_engine.executor.model self.asr_engine = asr_engine self.init() @@ -149,6 +151,10 @@ class PaddleASRConnectionHanddler: assert(len(self.cached_feat.shape) == 3) self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1) + # set the feat device + if self.device is None: + self.device = 
self.cached_feat.place + num_frames = x_chunk.shape[1] self.num_frames += num_frames self.remained_wav = self.remained_wav[self.n_shift * num_frames:] @@ -165,16 +171,17 @@ class PaddleASRConnectionHanddler: self.subsampling_cache = None self.elayers_output_cache = None self.conformer_cnn_cache = None - self.encoder_outs_ = None + self.encoder_out = None self.cached_feat = None self.remained_wav = None self.offset = 0 self.num_samples = 0 - + self.device = None + self.hyps = [] self.num_frames = 0 self.chunk_num = 0 self.global_frame_offset = 0 - self.result = [] + self.result_transcripts = [''] def decode(self, is_finished=False): if "deepspeech2online" in self.model_type: @@ -187,7 +194,6 @@ class PaddleASRConnectionHanddler: self.advance_decoding(is_finished) self.update_result() - return self.result_transcripts[0] except Exception as e: logger.exception(e) else: @@ -203,16 +209,26 @@ class PaddleASRConnectionHanddler: subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 stride = subsampling * decoding_chunk_size + cached_feature_num = context - subsampling # processed chunk feature cached for next chunk # decoding window for model decoding_window = (decoding_chunk_size - 1) * subsampling + context + if self.cached_feat is None: + logger.info("no audio feat, please input more pcm data") + return + num_frames = self.cached_feat.shape[1] logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames") # the cached feat must be larger decoding_window if num_frames < decoding_window and not is_finished: + logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data") return None, None - + + if num_frames < context: + logger.info("flast {num_frames} is less than context {context} frames, and we cannot do model forward") + return None, None + logger.info("start to do model forward") required_cache_size = decoding_chunk_size * num_decoding_left_chunks outputs = [] @@ -242,14 +258,18 @@ class PaddleASRConnectionHanddler: # update the offset self.offset += y.shape[1] + logger.info(f"output size: {len(outputs)}") ys = paddle.cat(outputs, 1) - masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) - masks = masks.unsqueeze(1) + if self.encoder_out is None: + self.encoder_out = ys + else: + self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) + # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) + # masks = masks.unsqueeze(1) # get the ctc probs ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) - # self.searcher.search(xs, ctc_probs, xs.place) self.searcher.search(None, ctc_probs, self.cached_feat.place) @@ -260,7 +280,8 @@ class PaddleASRConnectionHanddler: self.cached_feat = None else: assert self.cached_feat.shape[0] == 1 - self.cached_feat = self.cached_feat[0,end:,:].unsqueeze(0) + assert end >= cached_feature_num + self.cached_feat = self.cached_feat[0,end - cached_feature_num:,:].unsqueeze(0) assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" # ys for rescoring @@ -274,9 +295,75 @@ class PaddleASRConnectionHanddler: ] self.result_tokenids = [hyp for hyp in hyps] + def get_result(self): + if len(self.result_transcripts) > 0: + return self.result_transcripts[0] + else: + return '' + def rescoring(self): - pass + logger.info("rescoring the final result") + if "attention_rescoring" != self.ctc_decode_config.decoding_method: + return + + 
self.searcher.finalize_search() + self.update_result() + + beam_size = self.ctc_decode_config.beam_size + hyps = self.searcher.get_hyps() + if hyps is None or len(hyps) == 0: + return + + # assert len(hyps) == beam_size + paddle.save(self.encoder_out, "encoder.out") + hyp_list = [] + for hyp in hyps: + hyp_content = hyp[0] + # Prevent the hyp is empty + if len(hyp_content) == 0: + hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( + hyp_content, place=self.device, dtype=paddle.long) + hyp_list.append(hyp_content) + hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + hyps_lens = paddle.to_tensor( + [len(hyp[0]) for hyp in hyps], place=self.device, + dtype=paddle.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, + self.model.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + + encoder_out = self.encoder_out.repeat(beam_size, 1, 1) + encoder_mask = paddle.ones( + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + decoder_out, _ = self.model.decoder( + encoder_out, encoder_mask, hyps_pad, + hyps_lens) # (beam_size, max_hyps_len, vocab_size) + # ctc score in ln domain + decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) + decoder_out = decoder_out.numpy() + + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + # last decoder output token is `eos`, for laste decoder input token. + score += decoder_out[i][len(hyp[0])][self.model.eos] + # add ctc score (which in ln domain) + score += hyp[1] * self.ctc_decode_config.ctc_weight + if score > best_score: + best_score = score + best_index = i + # update the one best result + logger.info(f"best index: {best_index}") + self.hyps = [hyps[best_index][0]] + self.update_result() + # return hyps[best_index][0] @@ -552,7 +639,7 @@ class ASRServerExecutor(ASRExecutor): subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 stride = subsampling * decoding_chunk_size - + # decoding window for model decoding_window = (decoding_chunk_size - 1) * subsampling + context num_frames = xs.shape[1] diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index bf4c4b30..c3822b5c 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -110,6 +110,11 @@ class CTCPrefixBeamSearch: return [self.hyps[0][0]] def get_hyps(self): + """Return the search hyps + + Returns: + list: return the search hyps + """ return self.hyps def reset(self): @@ -117,3 +122,8 @@ class CTCPrefixBeamSearch: """ self.cur_hyps = None self.hyps = None + + def finalize_search(self): + """do nothing in ctc_prefix_beam_search + """ + pass diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 04807e5c..ae7c5eb4 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -13,16 +13,15 @@ # limitations under the License. 
import json -import numpy as np from fastapi import APIRouter from fastapi import WebSocket from fastapi import WebSocketDisconnect from starlette.websockets import WebSocketState as WebSocketState +from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.utils.buffer import ChunkBuffer from paddlespeech.server.utils.vad import VADAudio -from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler router = APIRouter() @@ -73,13 +72,17 @@ async def websocket_endpoint(websocket: WebSocket): connection_handler = PaddleASRConnectionHanddler(asr_engine) await websocket.send_json(resp) elif message['signal'] == 'end': - engine_pool = get_engine_pool() - asr_engine = engine_pool['asr'] # reset single engine for an new connection - asr_results = connection_handler.decode(is_finished=True) + connection_handler.decode(is_finished=True) + connection_handler.rescoring() + asr_results = connection_handler.get_result() connection_handler.reset() - asr_engine.reset() - resp = {"status": "ok", "signal": "finished", 'asr_results': asr_results} + + resp = { + "status": "ok", + "signal": "finished", + 'asr_results': asr_results + } await websocket.send_json(resp) break else: @@ -87,25 +90,11 @@ async def websocket_endpoint(websocket: WebSocket): await websocket.send_json(resp) elif "bytes" in message: message = message["bytes"] - engine_pool = get_engine_pool() - asr_engine = engine_pool['asr'] - asr_results = "" + connection_handler.extract_feat(message) - asr_results = connection_handler.decode(is_finished=False) - # connection_handler. - # frames = chunk_buffer.frame_generator(message) - # for frame in frames: - # # get the pcm data from the bytes - # samples = np.frombuffer(frame.bytes, dtype=np.int16) - # sample_rate = asr_engine.config.sample_rate - # x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - # sample_rate) - # asr_engine.run(x_chunk, x_chunk_lens) - # asr_results = asr_engine.postprocess() + connection_handler.decode(is_finished=False) + asr_results = connection_handler.get_result() - # # connection accept the sample data frame by frame - - # asr_results = asr_engine.postprocess() resp = {'asr_results': asr_results} print("\n") await websocket.send_json(resp) From 5acb0b5252e77018fdca05435c97638ac48f5d6a Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Mon, 18 Apr 2022 21:46:57 +0800 Subject: [PATCH 14/31] fix the websocket chunk edge bug, test=doc --- .../server/engine/asr/online/asr_engine.py | 121 ++++++++++-------- paddlespeech/server/ws/asr_socket.py | 1 - 2 files changed, 66 insertions(+), 56 deletions(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 696d223a..a8e25f4b 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -60,9 +60,9 @@ pretrained_models = { }, "conformer2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', 'md5': - 'b450d5dfaea0ac227c595ce58d18b637', + '0ac93d390552336f2a906aec9e33c5fa', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -78,12 +78,19 @@ pretrained_models = { }, } -# ASR server connection process class +# ASR server connection process class class 
PaddleASRConnectionHanddler: def __init__(self, asr_engine): + """Init a Paddle ASR Connection Handler instance + + Args: + asr_engine (ASREngine): the global asr engine + """ super().__init__() - logger.info("create an paddle asr connection handler to process the websocket connection") + logger.info( + "create an paddle asr connection handler to process the websocket connection" + ) self.config = asr_engine.config self.model_config = asr_engine.executor.config self.model = asr_engine.executor.model @@ -98,24 +105,26 @@ class PaddleASRConnectionHanddler: pass elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: self.sample_rate = self.asr_engine.executor.sample_rate - + # acoustic model self.model = self.asr_engine.executor.model - + # tokens to text self.text_feature = self.asr_engine.executor.text_feature - - # ctc decoding + + # ctc decoding config self.ctc_decode_config = self.asr_engine.executor.config.decode self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) - # extract fbank + # extract feat, new only fbank in conformer model self.preprocess_conf = self.model_config.preprocess_config self.preprocess_args = {"train": False} self.preprocessing = Transformation(self.preprocess_conf) + + # frame window samples length and frame shift samples length self.win_length = self.preprocess_conf.process[0]['win_length'] self.n_shift = self.preprocess_conf.process[0]['n_shift'] - + def extract_feat(self, samples): if "deepspeech2online" in self.model_type: pass @@ -123,10 +132,10 @@ class PaddleASRConnectionHanddler: logger.info("Online ASR extract the feat") samples = np.frombuffer(samples, dtype=np.int16) assert samples.ndim == 1 - + logger.info(f"This package receive {samples.shape[0]} pcm data") self.num_samples += samples.shape[0] - + # self.reamined_wav stores all the samples, # include the original remained_wav and this package samples if self.remained_wav is None: @@ -141,19 +150,21 @@ class PaddleASRConnectionHanddler: return 0 # fbank - x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) + x_chunk = self.preprocessing(self.remained_wav, + **self.preprocess_args) x_chunk = paddle.to_tensor( x_chunk, dtype="float32").unsqueeze(axis=0) if self.cached_feat is None: self.cached_feat = x_chunk else: - assert(len(x_chunk.shape) == 3) - assert(len(self.cached_feat.shape) == 3) - self.cached_feat = paddle.concat([self.cached_feat, x_chunk], axis=1) - + assert (len(x_chunk.shape) == 3) + assert (len(self.cached_feat.shape) == 3) + self.cached_feat = paddle.concat( + [self.cached_feat, x_chunk], axis=1) + # set the feat device if self.device is None: - self.device = self.cached_feat.place + self.device = self.cached_feat.place num_frames = x_chunk.shape[1] self.num_frames += num_frames @@ -161,7 +172,7 @@ class PaddleASRConnectionHanddler: logger.info( f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" - ) + ) logger.info( f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" ) @@ -209,24 +220,30 @@ class PaddleASRConnectionHanddler: subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 stride = subsampling * decoding_chunk_size - cached_feature_num = context - subsampling # processed chunk feature cached for next chunk + cached_feature_num = context - subsampling # processed chunk feature cached for next chunk # decoding window for model decoding_window = (decoding_chunk_size - 1) * 
subsampling + context if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") return - + num_frames = self.cached_feat.shape[1] - logger.info(f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames") - + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + # the cached feat must be larger decoding_window if num_frames < decoding_window and not is_finished: - logger.info(f"frame feat num is less than {decoding_window}, please input more pcm data") + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) return None, None if num_frames < context: - logger.info("flast {num_frames} is less than context {context} frames, and we cannot do model forward") + logger.info( + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + ) return None, None logger.info("start to do model forward") @@ -235,17 +252,17 @@ class PaddleASRConnectionHanddler: # num_frames - context + 1 ensure that current frame can get context window if is_finished: - # if get the finished chunk, we need process the last context + # if get the finished chunk, we need process the last context left_frames = context else: # we only process decoding_window frames for one chunk - left_frames = decoding_window - + left_frames = decoding_window + # record the end for removing the processed feat end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) - + self.chunk_num += 1 chunk_xs = self.cached_feat[:, cur:end, :] (y, self.subsampling_cache, self.elayers_output_cache, @@ -257,35 +274,31 @@ class PaddleASRConnectionHanddler: # update the offset self.offset += y.shape[1] - - logger.info(f"output size: {len(outputs)}") + ys = paddle.cat(outputs, 1) if self.encoder_out is None: - self.encoder_out = ys + self.encoder_out = ys else: self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) - # masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) - # masks = masks.unsqueeze(1) - + # get the ctc probs ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) self.searcher.search(None, ctc_probs, self.cached_feat.place) - + self.hyps = self.searcher.get_one_best_hyps() + assert self.cached_feat.shape[0] == 1 + assert end >= cached_feature_num + self.cached_feat = self.cached_feat[0, end - + cached_feature_num:, :].unsqueeze(0) + assert len( + self.cached_feat.shape + ) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - # remove the processed feat - if end == num_frames: - self.cached_feat = None - else: - assert self.cached_feat.shape[0] == 1 - assert end >= cached_feature_num - self.cached_feat = self.cached_feat[0,end - cached_feature_num:,:].unsqueeze(0) - assert len(self.cached_feat.shape) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - - # ys for rescoring - # return ys, masks + logger.info( + f"This connection handler encoder out shape: {self.encoder_out.shape}" + ) def update_result(self): logger.info("update the final result") @@ -304,8 +317,8 @@ class PaddleASRConnectionHanddler: def rescoring(self): logger.info("rescoring the final result") if "attention_rescoring" != self.ctc_decode_config.decoding_method: - return - + return + self.searcher.finalize_search() self.update_result() @@ -363,8 +376,6 @@ class PaddleASRConnectionHanddler: logger.info(f"best index: {best_index}") 
self.hyps = [hyps[best_index][0]] self.update_result() - # return hyps[best_index][0] - class ASRServerExecutor(ASRExecutor): @@ -409,9 +420,9 @@ class ASRServerExecutor(ASRExecutor): logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path - self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml" - # self.cfg_path = os.path.join(res_path, - # pretrained_models[tag]['cfg_path']) + + self.cfg_path = os.path.join(res_path, + pretrained_models[tag]['cfg_path']) self.am_model = os.path.join(res_path, pretrained_models[tag]['model']) @@ -639,7 +650,7 @@ class ASRServerExecutor(ASRExecutor): subsampling = self.model.encoder.embed.subsampling_rate context = self.model.encoder.embed.right_context + 1 stride = subsampling * decoding_chunk_size - + # decoding window for model decoding_window = (decoding_chunk_size - 1) * subsampling + context num_frames = xs.shape[1] diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index ae7c5eb4..82b05bc5 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -96,7 +96,6 @@ async def websocket_endpoint(websocket: WebSocket): asr_results = connection_handler.get_result() resp = {'asr_results': asr_results} - print("\n") await websocket.send_json(resp) except WebSocketDisconnect: pass From 40dde22fc48f41cffdace68847ccbeb00cc1cef4 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Tue, 19 Apr 2022 12:59:48 +0800 Subject: [PATCH 15/31] code format, test=doc --- .../server/engine/tts/online/tts_engine.py | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/tts_engine.py index 8e76225d..a84644e7 100644 --- a/paddlespeech/server/engine/tts/online/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/tts_engine.py @@ -127,33 +127,40 @@ class TTSServerExecutor(TTSExecutor): self.voc_block = voc_block self.voc_pad = voc_pad - def get_model_info(self, step, model_name, ckpt, stat): + def get_model_info(self, + field: str, + model_name: str, + ckpt: Optional[os.PathLike], + stat: Optional[os.PathLike]): """get model information Args: - step (string): am or voc - model_name (string): model type, support fastspeech2, higigan, mb_melgan - ckpt (string): ckpt file - stat (string): stat file, including mean and standard deviation + field (str): am or voc + model_name (str): model type, support fastspeech2, higigan, mb_melgan + ckpt (Optional[os.PathLike]): ckpt file + stat (Optional[os.PathLike]): stat file, including mean and standard deviation Returns: - model, model_mu, model_std + [module]: model module + [Tensor]: mean + [Tensor]: standard deviation """ + model_class = dynamic_import(model_name, model_alias) - if step == "am": + if field == "am": odim = self.am_config.n_mels model = model_class( idim=self.vocab_size, odim=odim, **self.am_config["model"]) model.set_state_dict(paddle.load(ckpt)["main_params"]) - elif step == "voc": + elif field == "voc": model = model_class(**self.voc_config["generator_params"]) model.set_state_dict(paddle.load(ckpt)["generator_params"]) model.remove_weight_norm() else: - logger.error("Please set correct step, am or voc") + logger.error("Please set correct field, am or voc") model.eval() model_mu, model_std = np.load(stat) @@ -346,7 +353,8 @@ class TTSServerExecutor(TTSExecutor): voc_block = self.voc_block 
voc_pad = self.voc_pad voc_upsample = self.voc_config.n_shift - flag = 1 + # first_flag 用于标记首包 + first_flag = 1 get_tone_ids = False merge_sentences = False @@ -376,7 +384,7 @@ class TTSServerExecutor(TTSExecutor): if am == "fastspeech2_csmsc": # am mel = self.am_inference(part_phone_ids) - if flag == 1: + if first_flag == 1: first_am_et = time.time() self.first_am_infer = first_am_et - frontend_et @@ -388,11 +396,11 @@ class TTSServerExecutor(TTSExecutor): sub_wav = self.voc_inference(mel_chunk) sub_wav = self.depadding(sub_wav, voc_chunk_num, i, voc_block, voc_pad, voc_upsample) - if flag == 1: + if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et self.first_response_time = first_voc_et - frontend_st - flag = 0 + first_flag = 0 yield sub_wav @@ -427,9 +435,10 @@ class TTSServerExecutor(TTSExecutor): (mel_streaming, sub_mel), axis=0) # streaming voc + # 当流式AM推理的mel帧数大于流式voc推理的chunk size,开始进行流式voc 推理 while (mel_streaming.shape[0] >= end and voc_chunk_id < voc_chunk_num): - if flag == 1: + if first_flag == 1: first_am_et = time.time() self.first_am_infer = first_am_et - frontend_et voc_chunk = mel_streaming[start:end, :] @@ -439,11 +448,11 @@ class TTSServerExecutor(TTSExecutor): sub_wav = self.depadding(sub_wav, voc_chunk_num, voc_chunk_id, voc_block, voc_pad, voc_upsample) - if flag == 1: + if first_flag == 1: first_voc_et = time.time() self.first_voc_infer = first_voc_et - first_am_et self.first_response_time = first_voc_et - frontend_st - flag = 0 + first_flag = 0 yield sub_wav @@ -470,7 +479,8 @@ class TTSEngine(BaseEngine): def __init__(self, name=None): """Initialize TTS server engine """ - super(TTSEngine, self).__init__() + #super(TTSEngine, self).__init__() + super().__init__() def init(self, config: dict) -> bool: self.config = config From 9e41ac8550b5f53b77ce3656e3561c58e0f25a82 Mon Sep 17 00:00:00 2001 From: lym0302 Date: Tue, 19 Apr 2022 15:51:44 +0800 Subject: [PATCH 16/31] code format, test=doc --- paddlespeech/server/engine/tts/online/tts_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/tts_engine.py index a84644e7..c9135b88 100644 --- a/paddlespeech/server/engine/tts/online/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/tts_engine.py @@ -479,7 +479,6 @@ class TTSEngine(BaseEngine): def __init__(self, name=None): """Initialize TTS server engine """ - #super(TTSEngine, self).__init__() super().__init__() def init(self, config: dict) -> bool: From 166757703f841ce7a16d67a625bcf5eea7ee6230 Mon Sep 17 00:00:00 2001 From: qingen Date: Tue, 19 Apr 2022 16:30:23 +0800 Subject: [PATCH 17/31] [vec][loss] add NCE Loss from RNNLM, test=doc fix #1717 --- paddlespeech/vector/modules/loss.py | 131 ++++++++++++++++++++++ paddlespeech/vector/utils/vector_utils.py | 8 ++ 2 files changed, 139 insertions(+) diff --git a/paddlespeech/vector/modules/loss.py b/paddlespeech/vector/modules/loss.py index 1c80dda4..015c0dfe 100644 --- a/paddlespeech/vector/modules/loss.py +++ b/paddlespeech/vector/modules/loss.py @@ -91,3 +91,134 @@ class LogSoftmaxWrapper(nn.Layer): predictions = F.log_softmax(predictions, axis=1) loss = self.criterion(predictions, targets) / targets.sum() return loss + + +class NCELoss(nn.Layer): + """Noise Contrastive Estimation loss funtion + + Noise Contrastive Estimation (NCE) is an approximation method that is used to + work around the huge computational cost of large softmax layer. 
+ The basic idea is to convert the prediction problem into classification problem + at training stage. It has been proved that these two criterions converges to + the same minimal point as long as noise distribution is close enough to real one. + + NCE bridges the gap between generative models and discriminative models, + rather than simply speedup the softmax layer. + With NCE, you can turn almost anything into posterior with less effort (I think). + + Refs: + NCE:http://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann10AISTATS.pdf + Thanks: https://github.com/mingen-pan/easy-to-use-NCE-RNN-for-Pytorch/blob/master/nce.py + + Examples: + Q = Q_from_tokens(output_dim) + NCELoss(Q) + """ + + def __init__(self, Q, noise_ratio=100, Z_offset=9.5): + """Noise Contrastive Estimation loss funtion + + Args: + Q (tensor): prior model, uniform or guassian + noise_ratio (int, optional): noise sampling times. Defaults to 100. + Z_offset (float, optional): scale of post processing the score. Defaults to 9.5. + """ + super(NCELoss, self).__init__() + assert type(noise_ratio) is int + self.Q = paddle.to_tensor(Q, stop_gradient=False) + self.N = self.Q.shape[0] + self.K = noise_ratio + self.Z_offset = Z_offset + + def forward(self, output, target): + """Forward inference + """ + output = paddle.reshape(output, [-1, self.N]) + B = output.shape[0] + noise_idx = self.get_noise(B) + idx = self.get_combined_idx(target, noise_idx) + P_target, P_noise = self.get_prob(idx, output, sep_target=True) + Q_target, Q_noise = self.get_Q(idx) + loss = self.nce_loss(P_target, P_noise, Q_noise, Q_target) + return loss.mean() + + def get_Q(self, idx, sep_target=True): + """Get prior model of batchsize data + """ + idx_size = idx.size + prob_model = paddle.to_tensor( + self.Q.numpy()[paddle.reshape(idx, [-1]).numpy()]) + prob_model = paddle.reshape(prob_model, [idx.shape[0], idx.shape[1]]) + if sep_target: + return prob_model[:, 0], prob_model[:, 1:] + else: + return prob_model + + def get_prob(self, idx, scores, sep_target=True): + """Post processing the score of post model(output of nn) of batchsize data + """ + scores = self.get_scores(idx, scores) + scale = paddle.to_tensor([self.Z_offset], dtype='float32') + scores = paddle.add(scores, -scale) + prob = paddle.exp(scores) + if sep_target: + return prob[:, 0], prob[:, 1:] + else: + return prob + + def get_scores(self, idx, scores): + """Get the score of post model(output of nn) of batchsize data + """ + B, N = scores.shape + K = idx.shape[1] + idx_increment = paddle.to_tensor( + N * paddle.reshape(paddle.arange(B), [B, 1]) * paddle.ones([1, K]), + dtype="int64", + stop_gradient=False) + new_idx = idx_increment + idx + new_scores = paddle.index_select( + paddle.reshape(scores, [-1]), paddle.reshape(new_idx, [-1])) + + return paddle.reshape(new_scores, [B, K]) + + def get_noise(self, batch_size, uniform=True): + """Select noise sample + """ + if uniform: + noise = np.random.randint(self.N, size=self.K * batch_size) + else: + noise = np.random.choice( + self.N, self.K * batch_size, replace=True, p=self.Q.data) + noise = paddle.to_tensor(noise, dtype='int64', stop_gradient=False) + noise_idx = paddle.reshape(noise, [batch_size, self.K]) + return noise_idx + + def get_combined_idx(self, target_idx, noise_idx): + """Combined target and noise + """ + target_idx = paddle.reshape(target_idx, [-1, 1]) + return paddle.concat((target_idx, noise_idx), 1) + + def nce_loss(self, prob_model, prob_noise_in_model, prob_noise, + prob_target_in_noise): + """Combined the loss of target and noise + 
""" + + def safe_log(tensor): + """Safe log + """ + EPSILON = 1e-10 + return paddle.log(EPSILON + tensor) + + model_loss = safe_log(prob_model / + (prob_model + self.K * prob_target_in_noise)) + model_loss = paddle.reshape(model_loss, [-1]) + + noise_loss = paddle.sum( + safe_log((self.K * prob_noise) / + (prob_noise_in_model + self.K * prob_noise)), -1) + noise_loss = paddle.reshape(noise_loss, [-1]) + + loss = -(model_loss + noise_loss) + + return loss diff --git a/paddlespeech/vector/utils/vector_utils.py b/paddlespeech/vector/utils/vector_utils.py index 46de7ffa..dcf0f1aa 100644 --- a/paddlespeech/vector/utils/vector_utils.py +++ b/paddlespeech/vector/utils/vector_utils.py @@ -30,3 +30,11 @@ def get_chunks(seg_dur, audio_id, audio_duration): for i in range(num_chunks) ] return chunk_lst + + +def Q_from_tokens(token_num): + """Get prior model, data from uniform, would support others(guassian) in future + """ + freq = [1] * token_num + Q = paddle.to_tensor(freq, dtype='float64') + return Q / Q.sum() From 380afbbc5d828f81204a5b9ab9088d4491ba0b70 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 16:18:42 +0800 Subject: [PATCH 18/31] add ds2 model multi session, test=doc --- paddlespeech/server/conf/ws_application.yaml | 50 +--- .../server/conf/ws_conformer_application.yaml | 45 ++++ .../server/engine/asr/online/asr_engine.py | 224 ++++++++++++++++-- 3 files changed, 263 insertions(+), 56 deletions(-) create mode 100644 paddlespeech/server/conf/ws_conformer_application.yaml diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index aa3c208b..dae4a3ff 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -18,44 +18,10 @@ engine_list: ['asr_online'] # ENGINE CONFIG # ################################################################################# -# ################################### ASR ######################################### -# ################### speech task: asr; engine_type: online ####################### -# asr_online: -# model_type: 'deepspeech2online_aishell' -# am_model: # the pdmodel file of am static model [optional] -# am_params: # the pdiparams file of am static model [optional] -# lang: 'zh' -# sample_rate: 16000 -# cfg_path: -# decode_method: -# force_yes: True - -# am_predictor_conf: -# device: # set 'gpu:id' or 'cpu' -# switch_ir_optim: True -# glog_info: False # True -> print glog -# summary: True # False -> do not show predictor config - -# chunk_buffer_conf: -# frame_duration_ms: 80 -# shift_ms: 40 -# sample_rate: 16000 -# sample_width: 2 - -# vad_conf: -# aggressiveness: 2 -# sample_rate: 16000 -# frame_duration_ms: 20 -# sample_width: 2 -# padding_ms: 200 -# padding_ratio: 0.9 - - - ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### asr_online: - model_type: 'conformer2online_aishell' + model_type: 'deepspeech2online_aishell' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' @@ -71,9 +37,19 @@ asr_online: summary: True # False -> do not show predictor config chunk_buffer_conf: + frame_duration_ms: 80 + shift_ms: 40 + sample_rate: 16000 + sample_width: 2 window_n: 7 # frame shift_n: 4 # frame - window_ms: 25 # ms + window_ms: 20 # ms shift_ms: 10 # ms + + vad_conf: + aggressiveness: 2 sample_rate: 16000 - sample_width: 2 \ No newline at end of file + 
frame_duration_ms: 20 + sample_width: 2 + padding_ms: 200 + padding_ratio: 0.9 \ No newline at end of file diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml new file mode 100644 index 00000000..1a775f85 --- /dev/null +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -0,0 +1,45 @@ +# This is the parameter configuration file for PaddleSpeech Serving. + +################################################################################# +# SERVER SETTING # +################################################################################# +host: 0.0.0.0 +port: 8090 + +# The task format in the engin_list is: _ +# task choices = ['asr_online', 'tts_online'] +# protocol = ['websocket', 'http'] (only one can be selected). +# websocket only support online engine type. +protocol: 'websocket' +engine_list: ['asr_online'] + + +################################################################################# +# ENGINE CONFIG # +################################################################################# + +################################### ASR ######################################### +################### speech task: asr; engine_type: online ####################### +asr_online: + model_type: 'conformer2online_aishell' + am_model: # the pdmodel file of am static model [optional] + am_params: # the pdiparams file of am static model [optional] + lang: 'zh' + sample_rate: 16000 + cfg_path: + decode_method: + force_yes: True + + am_predictor_conf: + device: # set 'gpu:id' or 'cpu' + switch_ir_optim: True + glog_info: False # True -> print glog + summary: True # False -> do not show predictor config + + chunk_buffer_conf: + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 25 # ms + shift_ms: 10 # ms + sample_rate: 16000 + sample_width: 2 \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index a8e25f4b..77eb5a21 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -13,7 +13,7 @@ # limitations under the License. 
import os from typing import Optional - +import copy import numpy as np import paddle from numpy import float32 @@ -93,7 +93,7 @@ class PaddleASRConnectionHanddler: ) self.config = asr_engine.config self.model_config = asr_engine.executor.config - self.model = asr_engine.executor.model + # self.model = asr_engine.executor.model self.asr_engine = asr_engine self.init() @@ -102,7 +102,32 @@ class PaddleASRConnectionHanddler: def init(self): self.model_type = self.asr_engine.executor.model_type if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: - pass + from paddlespeech.s2t.io.collator import SpeechCollator + self.sample_rate = self.asr_engine.executor.sample_rate + self.am_predictor = self.asr_engine.executor.am_predictor + self.text_feature = self.asr_engine.executor.text_feature + self.collate_fn_test = SpeechCollator.from_config(self.model_config) + self.decoder = CTCDecoder( + odim=self.model_config.output_dim, # is in vocab + enc_n_units=self.model_config.rnn_layer_size * 2, + blank_id=self.model_config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=self.model_config.get('ctc_grad_norm_type', None)) + + cfg = self.model_config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + # frame window samples length and frame shift samples length + + self.win_length = int(self.model_config.window_ms * self.sample_rate) + self.n_shift = int(self.model_config.stride_ms * self.sample_rate) + elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: self.sample_rate = self.asr_engine.executor.sample_rate @@ -127,7 +152,65 @@ class PaddleASRConnectionHanddler: def extract_feat(self, samples): if "deepspeech2online" in self.model_type: - pass + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The connection remain the audio samples: {self.remained_wav.shape}" + ) + + # pcm16 -> pcm 32 + samples = pcm2float(self.remained_wav) + # read audio + speech_segment = SpeechSegment.from_pcm( + samples, self.sample_rate, transcript=" ") + # audio augment + self.collate_fn_test.augmentation.transform_audio(speech_segment) + + # extract speech feature + spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( + speech_segment, self.collate_fn_test.keep_transcription_text) + # CMVN spectrum + if self.collate_fn_test._normalizer: + spectrum = self.collate_fn_test._normalizer.apply(spectrum) + + # spectrum augment + audio = self.collate_fn_test.augmentation.transform_feature( + spectrum) + + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + # audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + + if self.cached_feat is None: + self.cached_feat = audio + else: + assert (len(audio.shape) == 3) + assert (len(self.cached_feat.shape) == 3) + self.cached_feat = paddle.concat( + [self.cached_feat, audio], axis=1) + + # set the feat device + if self.device is 
None: + self.device = self.cached_feat.place + + self.num_frames += audio_len + self.remained_wav = self.remained_wav[self.n_shift * audio_len:] + + logger.info( + f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" + ) elif "conformer2online" in self.model_type: logger.info("Online ASR extract the feat") samples = np.frombuffer(samples, dtype=np.int16) @@ -179,24 +262,81 @@ class PaddleASRConnectionHanddler: # logger.info(f"accumulate samples: {self.num_samples}") def reset(self): - self.subsampling_cache = None - self.elayers_output_cache = None - self.conformer_cnn_cache = None - self.encoder_out = None - self.cached_feat = None - self.remained_wav = None - self.offset = 0 - self.num_samples = 0 - self.device = None - self.hyps = [] - self.num_frames = 0 - self.chunk_num = 0 - self.global_frame_offset = 0 - self.result_transcripts = [''] + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + # for deepspeech2 + self.chunk_state_h_box = copy.deepcopy(self.asr_engine.executor.chunk_state_h_box) + self.chunk_state_c_box = copy.deepcopy(self.asr_engine.executor.chunk_state_c_box) + self.decoder.reset_decoder(batch_size=1) + elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: + # for conformer online + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_out = None + self.cached_feat = None + self.remained_wav = None + self.offset = 0 + self.num_samples = 0 + self.device = None + self.hyps = [] + self.num_frames = 0 + self.chunk_num = 0 + self.global_frame_offset = 0 + self.result_transcripts = [''] def decode(self, is_finished=False): if "deepspeech2online" in self.model_type: - pass + # x_chunk 是特征数据 + decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model + context = 7 # context=7 in deepspeech2 model + subsampling = 4 # subsampling=4 in deepspeech2 model + stride = subsampling * decoding_chunk_size + cached_feature_num = context - subsampling + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + + if self.cached_feat is None: + logger.info("no audio feat, please input more pcm data") + return + + num_frames = self.cached_feat.shape[1] + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + # the cached feat must be larger decoding_window + if num_frames < decoding_window and not is_finished: + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) + return None, None + + # if is_finished=True, we need at least context frames + if num_frames < context: + logger.info( + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + ) + return None, None + logger.info("start to do model forward") + # num_frames - context + 1 ensure that current frame can get context window + if is_finished: + # if get the finished chunk, we need process the last context + left_frames = context + else: + # we only process decoding_window frames for one chunk + left_frames = decoding_window + + for cur in range(0, num_frames - left_frames + 1, stride): + end = min(cur + decoding_window, num_frames) + # extract the audio + x_chunk = self.cached_feat[:, cur:end, :].numpy() + x_chunk_lens = np.array([x_chunk.shape[1]]) + 
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens) + + self.result_transcripts = [trans_best] + + self.cached_feat = self.cached_feat[:, end - + cached_feature_num:, :] + # return trans_best[0] elif "conformer" in self.model_type or "transformer" in self.model_type: try: logger.info( @@ -210,6 +350,48 @@ class PaddleASRConnectionHanddler: else: raise Exception("invalid model name") + def decode_one_chunk(self, x_chunk, x_chunk_lens): + logger.info("start to decoce one chunk with deepspeech2 model") + input_names = self.am_predictor.get_input_names() + audio_handle = self.am_predictor.get_input_handle(input_names[0]) + audio_len_handle = self.am_predictor.get_input_handle( + input_names[1]) + h_box_handle = self.am_predictor.get_input_handle(input_names[2]) + c_box_handle = self.am_predictor.get_input_handle(input_names[3]) + + audio_handle.reshape(x_chunk.shape) + audio_handle.copy_from_cpu(x_chunk) + + audio_len_handle.reshape(x_chunk_lens.shape) + audio_len_handle.copy_from_cpu(x_chunk_lens) + + h_box_handle.reshape(self.chunk_state_h_box.shape) + h_box_handle.copy_from_cpu(self.chunk_state_h_box) + + c_box_handle.reshape(self.chunk_state_c_box.shape) + c_box_handle.copy_from_cpu(self.chunk_state_c_box) + + output_names = self.am_predictor.get_output_names() + output_handle = self.am_predictor.get_output_handle(output_names[0]) + output_lens_handle = self.am_predictor.get_output_handle( + output_names[1]) + output_state_h_handle = self.am_predictor.get_output_handle( + output_names[2]) + output_state_c_handle = self.am_predictor.get_output_handle( + output_names[3]) + + self.am_predictor.run() + + output_chunk_probs = output_handle.copy_to_cpu() + output_chunk_lens = output_lens_handle.copy_to_cpu() + self.chunk_state_h_box = output_state_h_handle.copy_to_cpu() + self.chunk_state_c_box = output_state_c_handle.copy_to_cpu() + + self.decoder.next(output_chunk_probs, output_chunk_lens) + trans_best, trans_beam = self.decoder.decode() + logger.info(f"decode one one best result: {trans_best[0]}") + return trans_best[0] + def advance_decoding(self, is_finished=False): logger.info("start to decode with advanced_decoding method") cfg = self.ctc_decode_config @@ -240,6 +422,7 @@ class PaddleASRConnectionHanddler: ) return None, None + # if is_finished=True, we need at least context frames if num_frames < context: logger.info( "flast {num_frames} is less than context {context} frames, and we cannot do model forward" @@ -315,6 +498,9 @@ class PaddleASRConnectionHanddler: return '' def rescoring(self): + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + return + logger.info("rescoring the final result") if "attention_rescoring" != self.ctc_decode_config.decoding_method: return From f56dba0ca7da29aa5ad11f5ad83e4ee62f1a2fa4 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 17:57:50 +0800 Subject: [PATCH 19/31] fix the code format, test=doc --- paddlespeech/cli/asr/infer.py | 2 +- .../server/conf/ws_conformer_application.yaml | 2 +- .../server/engine/asr/online/asr_engine.py | 123 +++++++++--------- 3 files changed, 64 insertions(+), 63 deletions(-) diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 53f71a70..f1e46ca1 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -129,7 +129,7 @@ model_alias = { "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "conformer": "paddlespeech.s2t.models.u2:U2Model", - "conformer2online": + "conformer_online": 
"paddlespeech.s2t.models.u2:U2Model", "transformer": "paddlespeech.s2t.models.u2:U2Model", diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml index 1a775f85..89a861ef 100644 --- a/paddlespeech/server/conf/ws_conformer_application.yaml +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -21,7 +21,7 @@ engine_list: ['asr_online'] ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### asr_online: - model_type: 'conformer2online_aishell' + model_type: 'conformer_online_multi-cn' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model [optional] lang: 'zh' diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 77eb5a21..3c2b066c 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import os from typing import Optional -import copy + import numpy as np import paddle from numpy import float32 @@ -58,7 +59,7 @@ pretrained_models = { 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' }, - "conformer2online_aishell-zh-16k": { + "conformer_online_multi-cn-zh-16k": { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', 'md5': @@ -93,19 +94,22 @@ class PaddleASRConnectionHanddler: ) self.config = asr_engine.config self.model_config = asr_engine.executor.config - # self.model = asr_engine.executor.model self.asr_engine = asr_engine self.init() self.reset() def init(self): + # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer self.model_type = self.asr_engine.executor.model_type + self.sample_rate = self.asr_engine.executor.sample_rate + # tokens to text + self.text_feature = self.asr_engine.executor.text_feature + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: from paddlespeech.s2t.io.collator import SpeechCollator - self.sample_rate = self.asr_engine.executor.sample_rate self.am_predictor = self.asr_engine.executor.am_predictor - self.text_feature = self.asr_engine.executor.text_feature + self.collate_fn_test = SpeechCollator.from_config(self.model_config) self.decoder = CTCDecoder( odim=self.model_config.output_dim, # is in vocab @@ -114,7 +118,8 @@ class PaddleASRConnectionHanddler: dropout_rate=0.0, reduction=True, # sum batch_average=True, # sum / batch_size - grad_norm_type=self.model_config.get('ctc_grad_norm_type', None)) + grad_norm_type=self.model_config.get('ctc_grad_norm_type', + None)) cfg = self.model_config.decode decode_batch_size = 1 # for online @@ -123,20 +128,16 @@ class PaddleASRConnectionHanddler: cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) - # frame window samples length and frame shift samples length - - self.win_length = int(self.model_config.window_ms * self.sample_rate) - self.n_shift = int(self.model_config.stride_ms * self.sample_rate) + # frame window samples length and frame shift samples length - elif "conformer" in self.model_type or "transformer" in self.model_type or 
"wenetspeech" in self.model_type: - self.sample_rate = self.asr_engine.executor.sample_rate + self.win_length = int(self.model_config.window_ms * + self.sample_rate) + self.n_shift = int(self.model_config.stride_ms * self.sample_rate) + elif "conformer" in self.model_type or "transformer" in self.model_type: # acoustic model self.model = self.asr_engine.executor.model - # tokens to text - self.text_feature = self.asr_engine.executor.text_feature - # ctc decoding config self.ctc_decode_config = self.asr_engine.executor.config.decode self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) @@ -189,7 +190,7 @@ class PaddleASRConnectionHanddler: audio = paddle.to_tensor(audio, dtype='float32') # audio_len = paddle.to_tensor(audio_len) audio = paddle.unsqueeze(audio, axis=0) - + if self.cached_feat is None: self.cached_feat = audio else: @@ -211,7 +212,7 @@ class PaddleASRConnectionHanddler: logger.info( f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" ) - elif "conformer2online" in self.model_type: + elif "conformer_online" in self.model_type: logger.info("Online ASR extract the feat") samples = np.frombuffer(samples, dtype=np.int16) assert samples.ndim == 1 @@ -264,41 +265,43 @@ class PaddleASRConnectionHanddler: def reset(self): if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: # for deepspeech2 - self.chunk_state_h_box = copy.deepcopy(self.asr_engine.executor.chunk_state_h_box) - self.chunk_state_c_box = copy.deepcopy(self.asr_engine.executor.chunk_state_c_box) + self.chunk_state_h_box = copy.deepcopy( + self.asr_engine.executor.chunk_state_h_box) + self.chunk_state_c_box = copy.deepcopy( + self.asr_engine.executor.chunk_state_c_box) self.decoder.reset_decoder(batch_size=1) - elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: - # for conformer online - self.subsampling_cache = None - self.elayers_output_cache = None - self.conformer_cnn_cache = None - self.encoder_out = None - self.cached_feat = None - self.remained_wav = None - self.offset = 0 - self.num_samples = 0 - self.device = None - self.hyps = [] - self.num_frames = 0 - self.chunk_num = 0 - self.global_frame_offset = 0 - self.result_transcripts = [''] + + # for conformer online + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_out = None + self.cached_feat = None + self.remained_wav = None + self.offset = 0 + self.num_samples = 0 + self.device = None + self.hyps = [] + self.num_frames = 0 + self.chunk_num = 0 + self.global_frame_offset = 0 + self.result_transcripts = [''] def decode(self, is_finished=False): if "deepspeech2online" in self.model_type: # x_chunk 是特征数据 - decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model - context = 7 # context=7 in deepspeech2 model - subsampling = 4 # subsampling=4 in deepspeech2 model + decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model + context = 7 # context=7 in deepspeech2 model + subsampling = 4 # subsampling=4 in deepspeech2 model stride = subsampling * decoding_chunk_size cached_feature_num = context - subsampling # decoding window for model - decoding_window = (decoding_chunk_size - 1) * subsampling + context - + decoding_window = (decoding_chunk_size - 1) * subsampling + context + if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") - return - + return + num_frames = self.cached_feat.shape[1] logger.info( f"Required 
decoding window {decoding_window} frames, and the connection has {num_frames} frames" @@ -306,14 +309,14 @@ class PaddleASRConnectionHanddler: # the cached feat must be larger decoding_window if num_frames < decoding_window and not is_finished: logger.info( - f"frame feat num is less than {decoding_window}, please input more pcm data" + f"frame feat num is less than {decoding_window}, please input more pcm data" ) return None, None # if is_finished=True, we need at least context frames if num_frames < context: logger.info( - "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" ) return None, None logger.info("start to do model forward") @@ -334,8 +337,7 @@ class PaddleASRConnectionHanddler: self.result_transcripts = [trans_best] - self.cached_feat = self.cached_feat[:, end - - cached_feature_num:, :] + self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] # return trans_best[0] elif "conformer" in self.model_type or "transformer" in self.model_type: try: @@ -354,8 +356,7 @@ class PaddleASRConnectionHanddler: logger.info("start to decoce one chunk with deepspeech2 model") input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) - audio_len_handle = self.am_predictor.get_input_handle( - input_names[1]) + audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) h_box_handle = self.am_predictor.get_input_handle(input_names[2]) c_box_handle = self.am_predictor.get_input_handle(input_names[3]) @@ -374,11 +375,11 @@ class PaddleASRConnectionHanddler: output_names = self.am_predictor.get_output_names() output_handle = self.am_predictor.get_output_handle(output_names[0]) output_lens_handle = self.am_predictor.get_output_handle( - output_names[1]) + output_names[1]) output_state_h_handle = self.am_predictor.get_output_handle( - output_names[2]) + output_names[2]) output_state_c_handle = self.am_predictor.get_output_handle( - output_names[3]) + output_names[3]) self.am_predictor.run() @@ -389,7 +390,7 @@ class PaddleASRConnectionHanddler: self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() - logger.info(f"decode one one best result: {trans_best[0]}") + logger.info(f"decode one best result: {trans_best[0]}") return trans_best[0] def advance_decoding(self, is_finished=False): @@ -500,7 +501,7 @@ class PaddleASRConnectionHanddler: def rescoring(self): if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: return - + logger.info("rescoring the final result") if "attention_rescoring" != self.ctc_decode_config.decoding_method: return @@ -587,7 +588,7 @@ class ASRServerExecutor(ASRExecutor): return decompressed_path def _init_from_path(self, - model_type: str='wenetspeech', + model_type: str='deepspeech2online_aishell', am_model: Optional[os.PathLike]=None, am_params: Optional[os.PathLike]=None, lang: str='zh', @@ -647,7 +648,7 @@ class ASRServerExecutor(ASRExecutor): self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + elif "conformer" in model_type or "transformer" in model_type: logger.info("start to create the stream conformer asr engine") if self.config.spm_model_prefix: self.config.spm_model_prefix = os.path.join( @@ -711,7 +712,7 @@ class ASRServerExecutor(ASRExecutor): 
self.chunk_state_c_box = np.zeros( (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), dtype=float32) - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: + elif "conformer" in model_type or "transformer" in model_type: model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} logger.info(f"model name: {model_name}") @@ -742,7 +743,7 @@ class ASRServerExecutor(ASRExecutor): self.chunk_state_c_box = np.zeros( (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), dtype=float32) - elif "conformer" in self.model_type or "transformer" in self.model_type or "wenetspeech" in self.model_type: + elif "conformer" in self.model_type or "transformer" in self.model_type: self.transformer_decode_reset() def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str): @@ -754,7 +755,7 @@ class ASRServerExecutor(ASRExecutor): model_type (str): online model type Returns: - [type]: [description] + str: one best result """ logger.info("start to decoce chunk by chunk") if "deepspeech2online" in model_type: @@ -795,7 +796,7 @@ class ASRServerExecutor(ASRExecutor): self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() - logger.info(f"decode one one best result: {trans_best[0]}") + logger.info(f"decode one best result: {trans_best[0]}") return trans_best[0] elif "conformer" in model_type or "transformer" in model_type: @@ -972,7 +973,7 @@ class ASRServerExecutor(ASRExecutor): x_chunk_lens = np.array([audio_len]) return x_chunk, x_chunk_lens - elif "conformer2online" in self.model_type: + elif "conformer_online" in self.model_type: if sample_rate != self.sample_rate: logger.info(f"audio sample rate {sample_rate} is not match," @@ -1005,7 +1006,7 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() - logger.info("create the online asr engine instache") + logger.info("create the online asr engine instance") def init(self, config: dict) -> bool: """init engine resource From babac27a7943b5be254afab8af09e909b0d3151c Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 18:14:30 +0800 Subject: [PATCH 20/31] fix ds2 online edge bug, test=doc --- paddlespeech/cli/asr/pretrained_models.py | 2 ++ .../server/engine/asr/online/asr_engine.py | 20 +++++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/paddlespeech/cli/asr/pretrained_models.py b/paddlespeech/cli/asr/pretrained_models.py index a16c4750..cc52c751 100644 --- a/paddlespeech/cli/asr/pretrained_models.py +++ b/paddlespeech/cli/asr/pretrained_models.py @@ -88,6 +88,8 @@ model_alias = { "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "conformer": "paddlespeech.s2t.models.u2:U2Model", + "conformer_online": + "paddlespeech.s2t.models.u2:U2Model", "transformer": "paddlespeech.s2t.models.u2:U2Model", "wenetspeech": diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 3c2b066c..4d15d93b 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -130,9 +130,10 @@ class PaddleASRConnectionHanddler: cfg.num_proc_bsearch) # frame window samples length and frame shift samples length - self.win_length = int(self.model_config.window_ms * + self.win_length = int(self.model_config.window_ms / 1000 * self.sample_rate) - self.n_shift = int(self.model_config.stride_ms * self.sample_rate) + self.n_shift = 
int(self.model_config.stride_ms / 1000 * + self.sample_rate) elif "conformer" in self.model_type or "transformer" in self.model_type: # acoustic model @@ -158,6 +159,11 @@ class PaddleASRConnectionHanddler: samples = np.frombuffer(samples, dtype=np.int16) assert samples.ndim == 1 + # pcm16 -> pcm 32 + # pcm2float will change the orignal samples, + # so we shoule do pcm2float before concatenate + samples = pcm2float(samples) + if self.remained_wav is None: self.remained_wav = samples else: @@ -167,11 +173,9 @@ class PaddleASRConnectionHanddler: f"The connection remain the audio samples: {self.remained_wav.shape}" ) - # pcm16 -> pcm 32 - samples = pcm2float(self.remained_wav) # read audio speech_segment = SpeechSegment.from_pcm( - samples, self.sample_rate, transcript=" ") + self.remained_wav, self.sample_rate, transcript=" ") # audio augment self.collate_fn_test.augmentation.transform_audio(speech_segment) @@ -474,6 +478,7 @@ class PaddleASRConnectionHanddler: self.hyps = self.searcher.get_one_best_hyps() assert self.cached_feat.shape[0] == 1 assert end >= cached_feature_num + self.cached_feat = self.cached_feat[0, end - cached_feature_num:, :].unsqueeze(0) assert len( @@ -515,7 +520,6 @@ class PaddleASRConnectionHanddler: return # assert len(hyps) == beam_size - paddle.save(self.encoder_out, "encoder.out") hyp_list = [] for hyp in hyps: hyp_content = hyp[0] @@ -815,7 +819,7 @@ class ASRServerExecutor(ASRExecutor): def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): logger.info("start to decode with advanced_decoding method") - encoder_out, encoder_mask = self.decode_forward(xs) + encoder_out, encoder_mask = self.encoder_forward(xs) ctc_probs = self.model.ctc.log_softmax( encoder_out) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) @@ -827,7 +831,7 @@ class ASRServerExecutor(ASRExecutor): if "attention_rescoring" in self.config.decode.decoding_method: self.rescoring(encoder_out, xs.place) - def decode_forward(self, xs): + def encoder_forward(self, xs): logger.info("get the model out from the feat") cfg = self.config.decode decoding_chunk_size = cfg.decoding_chunk_size From 1133540682fef94e2baa5eef9288bfa10c82f57c Mon Sep 17 00:00:00 2001 From: Yang Zhou Date: Tue, 19 Apr 2022 20:10:32 +0800 Subject: [PATCH 21/31] add websocket --- speechx/CMakeLists.txt | 5 +- speechx/examples/ds2_ol/CMakeLists.txt | 3 +- speechx/examples/ds2_ol/aishell/path.sh | 4 +- speechx/examples/ds2_ol/aishell/run.sh | 12 +-- .../examples/ds2_ol/decoder/CMakeLists.txt | 3 + .../ctc-prefix-beam-search-decoder-ol.cc | 20 ++-- .../examples/ds2_ol/feat/cmvn-json2kaldi.cc | 4 +- .../feat/linear-spectrogram-wo-db-norm-ol.cc | 4 +- .../examples/ds2_ol/websocket/CMakeLists.txt | 10 ++ .../ds2_ol/websocket/websocket_client_main.cc | 82 ++++++++++++++++ .../ds2_ol/websocket/websocket_server_main.cc | 30 ++++++ speechx/speechx/CMakeLists.txt | 8 +- speechx/speechx/base/common.h | 2 + speechx/speechx/decoder/CMakeLists.txt | 3 +- speechx/speechx/decoder/ctc_tlg_decoder.cc | 3 +- speechx/speechx/decoder/param.h | 94 +++++++++++++++++++ speechx/speechx/decoder/recognizer.cc | 60 ++++++++++++ speechx/speechx/decoder/recognizer.h | 59 ++++++++++++ speechx/speechx/frontend/audio/CMakeLists.txt | 3 +- speechx/speechx/frontend/audio/audio_cache.cc | 2 +- speechx/speechx/frontend/audio/audio_cache.h | 2 +- .../speechx/frontend/audio/feature_cache.cc | 62 +++++++----- .../speechx/frontend/audio/feature_cache.h | 21 ++++- .../frontend/audio/feature_pipeline.cc | 36 +++++++ .../speechx/frontend/audio/feature_pipeline.h 
| 57 +++++++++++ .../frontend/audio/linear_spectrogram.cc | 8 +- .../frontend/audio/linear_spectrogram.h | 8 +- speechx/speechx/nnet/decodable.cc | 1 - 28 files changed, 537 insertions(+), 69 deletions(-) create mode 100644 speechx/examples/ds2_ol/websocket/CMakeLists.txt create mode 100644 speechx/examples/ds2_ol/websocket/websocket_client_main.cc create mode 100644 speechx/examples/ds2_ol/websocket/websocket_server_main.cc create mode 100644 speechx/speechx/decoder/param.h create mode 100644 speechx/speechx/decoder/recognizer.cc create mode 100644 speechx/speechx/decoder/recognizer.h create mode 100644 speechx/speechx/frontend/audio/feature_pipeline.cc create mode 100644 speechx/speechx/frontend/audio/feature_pipeline.h diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index f1330d1d..98d9e637 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -63,7 +63,8 @@ include(libsndfile) # include(boost) # not work set(boost_SOURCE_DIR ${fc_patch}/boost-src) set(BOOST_ROOT ${boost_SOURCE_DIR}) -# #find_package(boost REQUIRED PATHS ${BOOST_ROOT}) +include_directories(${boost_SOURCE_DIR}) +link_directories(${boost_SOURCE_DIR}/stage/lib) # Eigen include(eigen) @@ -141,4 +142,4 @@ set(DEPS ${DEPS} set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx) add_subdirectory(speechx) -add_subdirectory(examples) \ No newline at end of file +add_subdirectory(examples) diff --git a/speechx/examples/ds2_ol/CMakeLists.txt b/speechx/examples/ds2_ol/CMakeLists.txt index 89cbd0ef..08c19484 100644 --- a/speechx/examples/ds2_ol/CMakeLists.txt +++ b/speechx/examples/ds2_ol/CMakeLists.txt @@ -2,4 +2,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR) add_subdirectory(feat) add_subdirectory(nnet) -add_subdirectory(decoder) \ No newline at end of file +add_subdirectory(decoder) +add_subdirectory(websocket) diff --git a/speechx/examples/ds2_ol/aishell/path.sh b/speechx/examples/ds2_ol/aishell/path.sh index 8e26e6e7..520129ea 100644 --- a/speechx/examples/ds2_ol/aishell/path.sh +++ b/speechx/examples/ds2_ol/aishell/path.sh @@ -1,6 +1,6 @@ # This contains the locations of binarys build required for running the examples. -SPEECHX_ROOT=$PWD/../../../ +SPEECHX_ROOT=$PWD/../../.. SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_TOOLS=$SPEECHX_ROOT/tools @@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin export LC_AL=C -SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat +SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat:$SPEECHX_EXAMPLES/ds2_ol/websocket export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/aishell/run.sh b/speechx/examples/ds2_ol/aishell/run.sh index 3a1c19ee..0719ba14 100755 --- a/speechx/examples/ds2_ol/aishell/run.sh +++ b/speechx/examples/ds2_ol/aishell/run.sh @@ -42,7 +42,7 @@ fi if [ ! 
-d $ckpt_dir ]; then mkdir -p $ckpt_dir wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz - tar xzfv $model_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir + tar xzfv $ckpt_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir fi lm=$data/zh_giga.no_cna_cmn.prune01244.klm @@ -79,7 +79,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \ ctc-prefix-beam-search-decoder-ol \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ + --params_path=$model_dir/avg_1.jit.pdiparams \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --dict_file=$vocb_dir/vocab.txt \ --result_wspecifier=ark,t:$data/split${nj}/JOB/result @@ -92,7 +92,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \ ctc-prefix-beam-search-decoder-ol \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ + --params_path=$model_dir/avg_1.jit.pdiparams \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --dict_file=$vocb_dir/vocab.txt \ --lm_path=$lm \ @@ -104,9 +104,9 @@ utils/compute-wer.py --char=1 --v=1 ${label_file}_lm $text > ${wer}_lm graph_dir=./aishell_graph -if [ ! -d $ ]; then +if [ ! -d $graph_dir ]; then wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip - unzip -d aishell_graph.zip + unzip aishell_graph.zip fi @@ -115,7 +115,7 @@ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ wfst-decoder-ol \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ + --params_path=$model_dir/avg_1.jit.pdiparams \ --word_symbol_table=$graph_dir/words.txt \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --graph_path=$graph_dir/TLG.fst --max_active=7500 \ diff --git a/speechx/examples/ds2_ol/decoder/CMakeLists.txt b/speechx/examples/ds2_ol/decoder/CMakeLists.txt index 6139ebfa..62dd6862 100644 --- a/speechx/examples/ds2_ol/decoder/CMakeLists.txt +++ b/speechx/examples/ds2_ol/decoder/CMakeLists.txt @@ -17,3 +17,6 @@ add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) +add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc) +target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS}) diff --git a/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc b/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc index 49d64b69..21afec27 100644 --- a/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc +++ b/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc @@ -34,12 +34,10 @@ DEFINE_int32(receptive_field_length, DEFINE_int32(downsampling_rate, 4, "two CNN(kernel=5) module downsampling rate."); -DEFINE_string( - model_input_names, - 
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", - "model input names"); DEFINE_string(model_output_names, - "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", + "save_infer_model/scale_0.tmp_1,save_infer_model/" + "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/" + "scale_3.tmp_1", "model output names"); DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); @@ -52,18 +50,14 @@ int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); - CHECK(FLAGS_result_wspecifier != ""); - CHECK(FLAGS_feature_rspecifier != ""); - kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_rspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - - std::string model_graph = FLAGS_model_path; + std::string model_path = FLAGS_model_path; std::string model_params = FLAGS_param_path; std::string dict_file = FLAGS_dict_file; std::string lm_path = FLAGS_lm_path; - LOG(INFO) << "model path: " << model_graph; + LOG(INFO) << "model path: " << model_path; LOG(INFO) << "model param: " << model_params; LOG(INFO) << "dict path: " << dict_file; LOG(INFO) << "lm path: " << lm_path; @@ -76,10 +70,9 @@ int main(int argc, char* argv[]) { ppspeech::CTCBeamSearch decoder(opts); ppspeech::ModelOptions model_opts; - model_opts.model_path = model_graph; + model_opts.model_path = model_path; model_opts.params_path = model_params; model_opts.cache_shape = FLAGS_model_cache_names; - model_opts.input_names = FLAGS_model_input_names; model_opts.output_names = FLAGS_model_output_names; std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); @@ -125,7 +118,6 @@ int main(int argc, char* argv[]) { if (feature_chunk_size < receptive_field_length) break; int32 start = chunk_idx * chunk_stride; - int32 end = start + chunk_size; for (int row_id = 0; row_id < chunk_size; ++row_id) { kaldi::SubVector tmp(feature, start); diff --git a/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc b/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc index b8385664..0a9cfb06 100644 --- a/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc +++ b/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc @@ -73,9 +73,9 @@ int main(int argc, char* argv[]) { LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path; LOG(INFO) << "Binary: " << FLAGS_binary; } catch (simdjson::simdjson_error& err) { - LOG(ERR) << err.what(); + LOG(ERROR) << err.what(); } return 0; -} \ No newline at end of file +} diff --git a/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc b/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc index 27ca6f9f..0d10bd30 100644 --- a/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc +++ b/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc @@ -32,7 +32,6 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); - int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); @@ -66,7 +65,8 @@ int main(int argc, char* argv[]) { std::unique_ptr cmvn( new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram))); - ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); + ppspeech::FeatureCacheOptions feat_cache_opts; + ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); LOG(INFO) << "feat dim: " << feature_cache.Dim(); int 
sample_rate = 16000; diff --git a/speechx/examples/ds2_ol/websocket/CMakeLists.txt b/speechx/examples/ds2_ol/websocket/CMakeLists.txt new file mode 100644 index 00000000..754b528e --- /dev/null +++ b/speechx/examples/ds2_ol/websocket/CMakeLists.txt @@ -0,0 +1,10 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc) +target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS}) + +add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc) +target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS}) + diff --git a/speechx/examples/ds2_ol/websocket/websocket_client_main.cc b/speechx/examples/ds2_ol/websocket/websocket_client_main.cc new file mode 100644 index 00000000..68ea898a --- /dev/null +++ b/speechx/examples/ds2_ol/websocket/websocket_client_main.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
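//
// websocket_client_main: streams recordings from a wav scp to the online ASR
// websocket server (default 127.0.0.1:201314). For each utterance it sends a
// start signal, pushes the int16 samples in fixed-size chunks of
// streaming_chunk * 16000 samples via SendBinaryData(), pauses briefly between
// chunks, sends an end signal together with the last (short) chunk, waits
// until the server reports it is done, and logs the returned transcript.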
+ +#include "websocket/websocket_client.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/kaldi-io.h" +#include "kaldi/util/table-types.h" + +DEFINE_string(host, "127.0.0.1", "host of websocket server"); +DEFINE_int32(port, 201314, "port of websocket server"); +DEFINE_string(wav_rspecifier, "", "test wav scp path"); +DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size"); + +using kaldi::int16; +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + ppspeech::WebSocketClient client(FLAGS_host, FLAGS_port); + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + + const int sample_rate = 16000; + const float streaming_chunk = FLAGS_streaming_chunk; + const int chunk_sample_size = streaming_chunk * sample_rate; + + for (; !wav_reader.Done(); wav_reader.Next()) { + client.SendStartSignal(); + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + CHECK_EQ(wave_data.SampFreq(), sample_rate); + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + const int tot_samples = waveform.Dim(); + int sample_offset = 0; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = static_cast(waveform(sample_offset + i)); + } + client.SendBinaryData(wav_chunk.data(), + wav_chunk.size() * sizeof(int16)); + + + sample_offset += cur_chunk_size; + LOG(INFO) << "Send " << cur_chunk_size << " samples"; + std::this_thread::sleep_for( + std::chrono::milliseconds(static_cast(1 * 1000))); + + if (cur_chunk_size < chunk_sample_size) { + client.SendEndSignal(); + } + } + + while (!client.Done()) { + } + std::string result = client.GetResult(); + LOG(INFO) << "utt: " << utt << " " << result; + + + client.Join(); + return 0; + } + return 0; +} diff --git a/speechx/examples/ds2_ol/websocket/websocket_server_main.cc b/speechx/examples/ds2_ol/websocket/websocket_server_main.cc new file mode 100644 index 00000000..43cbd6bb --- /dev/null +++ b/speechx/examples/ds2_ol/websocket/websocket_server_main.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
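//
// websocket_server_main: entry point of the streaming ASR websocket server.
// It parses the gflags declared in decoder/param.h, builds a
// RecognizerResource (feature pipeline, nnet and TLG decoder options) through
// InitRecognizerResoure(), and starts a WebSocketServer listening on --port
// (default 201314).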
+ +#include "websocket/websocket_server.h" +#include "decoder/param.h" + +DEFINE_int32(port, 201314, "websocket listening port"); + +int main(int argc, char *argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure(); + + ppspeech::WebSocketServer server(FLAGS_port, resource); + LOG(INFO) << "Listening at port " << FLAGS_port; + server.Start(); + return 0; +} diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index 225abee7..b4da095d 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -30,4 +30,10 @@ include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/decoder ) -add_subdirectory(decoder) \ No newline at end of file +add_subdirectory(decoder) + +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/websocket +) +add_subdirectory(websocket) diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h index 7502bc5e..a9303cbb 100644 --- a/speechx/speechx/base/common.h +++ b/speechx/speechx/base/common.h @@ -28,8 +28,10 @@ #include #include #include +#include #include #include +#include #include #include "base/basic_types.h" diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index ee0863fd..06bf4020 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -7,5 +7,6 @@ add_library(decoder STATIC ctc_decoders/path_trie.cpp ctc_decoders/scorer.cpp ctc_tlg_decoder.cc + recognizer.cc ) -target_link_libraries(decoder PUBLIC kenlm utils fst) +target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc index 5365e709..7b720e7b 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -33,7 +33,6 @@ void TLGDecoder::InitDecoder() { void TLGDecoder::AdvanceDecode( const std::shared_ptr& decodable) { while (!decodable->IsLastFrame(frame_decoded_size_)) { - LOG(INFO) << "num frame decode: " << frame_decoded_size_; AdvanceDecoding(decodable.get()); } } @@ -63,4 +62,4 @@ std::string TLGDecoder::GetFinalBestPath() { } return words; } -} \ No newline at end of file +} diff --git a/speechx/speechx/decoder/param.h b/speechx/speechx/decoder/param.h new file mode 100644 index 00000000..cd50ef53 --- /dev/null +++ b/speechx/speechx/decoder/param.h @@ -0,0 +1,94 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
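//
// param.h: command-line flags shared by the ds2 online binaries, plus helpers
// that turn them into option structs: InitFeaturePipelineOptions() (cmvn file,
// pcm32 conversion, 20 ms / 10 ms hanning-window linear spectrogram, feature
// cache chunk = receptive_field_length and stride = downsampling_rate),
// InitModelOptions() (model/params path, cache shapes, output names),
// InitDecoderOptions() (word symbol table, TLG graph, beams, max_active) and
// InitRecognizerResoure(), which bundles everything into a RecognizerResource.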
+ +#pragma once + +#include "base/common.h" + +#include "decoder/ctc_beam_search_decoder.h" +#include "decoder/ctc_tlg_decoder.h" +#include "frontend/audio/feature_pipeline.h" + +DEFINE_string(cmvn_file, "", "read cmvn"); +DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size"); +DEFINE_bool(convert2PCM32, true, "audio convert to pcm32"); +DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); +DEFINE_string(params_path, "avg_1.jit.pdiparams", "paddle nnet model param"); +DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); +DEFINE_string(graph_path, "TLG", "decoder graph"); +DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); +DEFINE_int32(max_active, 7500, "max active"); +DEFINE_double(beam, 15.0, "decoder beam"); +DEFINE_double(lattice_beam, 7.5, "decoder beam"); +DEFINE_int32(receptive_field_length, + 7, + "receptive field of two CNN(kernel=5) downsampling module."); +DEFINE_int32(downsampling_rate, + 4, + "two CNN(kernel=5) module downsampling rate."); +DEFINE_string(model_output_names, + "save_infer_model/scale_0.tmp_1,save_infer_model/" + "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/" + "scale_3.tmp_1", + "model output names"); +DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); + +namespace ppspeech { +// todo refactor later +FeaturePipelineOptions InitFeaturePipelineOptions() { + FeaturePipelineOptions opts; + opts.cmvn_file = FLAGS_cmvn_file; + opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk; + opts.convert2PCM32 = FLAGS_convert2PCM32; + kaldi::FrameExtractionOptions frame_opts; + frame_opts.frame_length_ms = 20; + frame_opts.frame_shift_ms = 10; + frame_opts.remove_dc_offset = false; + frame_opts.window_type = "hanning"; + frame_opts.preemph_coeff = 0.0; + frame_opts.dither = 0.0; + opts.linear_spectrogram_opts.frame_opts = frame_opts; + opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length; + opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate; + return opts; +} + +ModelOptions InitModelOptions() { + ModelOptions model_opts; + model_opts.model_path = FLAGS_model_path; + model_opts.params_path = FLAGS_params_path; + model_opts.cache_shape = FLAGS_model_cache_names; + model_opts.output_names = FLAGS_model_output_names; + return model_opts; +} + +TLGDecoderOptions InitDecoderOptions() { + TLGDecoderOptions decoder_opts; + decoder_opts.word_symbol_table = FLAGS_word_symbol_table; + decoder_opts.fst_path = FLAGS_graph_path; + decoder_opts.opts.max_active = FLAGS_max_active; + decoder_opts.opts.beam = FLAGS_beam; + decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; + return decoder_opts; +} + +RecognizerResource InitRecognizerResoure() { + RecognizerResource resource; + resource.acoustic_scale = FLAGS_acoustic_scale; + resource.feature_pipeline_opts = InitFeaturePipelineOptions(); + resource.model_opts = InitModelOptions(); + resource.tlg_opts = InitDecoderOptions(); + return resource; +} +} \ No newline at end of file diff --git a/speechx/speechx/decoder/recognizer.cc b/speechx/speechx/decoder/recognizer.cc new file mode 100644 index 00000000..2c90ada9 --- /dev/null +++ b/speechx/speechx/decoder/recognizer.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/recognizer.h" + +namespace ppspeech { + +using kaldi::Vector; +using kaldi::VectorBase; +using kaldi::BaseFloat; +using std::vector; +using kaldi::SubVector; +using std::unique_ptr; + +Recognizer::Recognizer(const RecognizerResource& resource) { + // resource_ = resource; + const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; + feature_pipeline_.reset(new FeaturePipeline(feature_opts)); + std::shared_ptr nnet(new PaddleNnet(resource.model_opts)); + BaseFloat ac_scale = resource.acoustic_scale; + decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale)); + decoder_.reset(new TLGDecoder(resource.tlg_opts)); + input_finished_ = false; +} + +void Recognizer::Accept(const Vector& waves) { + feature_pipeline_->Accept(waves); +} + +void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); } + +std::string Recognizer::GetFinalResult() { + return decoder_->GetFinalBestPath(); +} + +void Recognizer::SetFinished() { + feature_pipeline_->SetFinished(); + input_finished_ = true; +} + +bool Recognizer::IsFinished() { return input_finished_; } + +void Recognizer::Reset() { + feature_pipeline_->Reset(); + decodable_->Reset(); + decoder_->Reset(); +} + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/decoder/recognizer.h new file mode 100644 index 00000000..9a7e7d11 --- /dev/null +++ b/speechx/speechx/decoder/recognizer.h @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
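//
// recognizer.h: ties the ds2 online pipeline together. A Recognizer owns a
// FeaturePipeline (raw audio -> normalized features), a PaddleNnet wrapped in
// a Decodable with the configured acoustic scale, and a TLGDecoder.
// Accept() feeds audio into the feature pipeline, Decode() advances WFST
// decoding over the frames available so far, SetFinished() flushes the
// pipeline, GetFinalResult() returns the best path, and Reset() clears all
// three components for the next utterance.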
+ +// todo refactor later (SGoat) + +#pragma once + +#include "decoder/ctc_beam_search_decoder.h" +#include "decoder/ctc_tlg_decoder.h" +#include "frontend/audio/feature_pipeline.h" +#include "nnet/decodable.h" +#include "nnet/paddle_nnet.h" + +namespace ppspeech { + +struct RecognizerResource { + FeaturePipelineOptions feature_pipeline_opts; + ModelOptions model_opts; + TLGDecoderOptions tlg_opts; + // CTCBeamSearchOptions beam_search_opts; + kaldi::BaseFloat acoustic_scale; + RecognizerResource() + : acoustic_scale(1.0), + feature_pipeline_opts(), + model_opts(), + tlg_opts() {} +}; + +class Recognizer { + public: + explicit Recognizer(const RecognizerResource& resouce); + void Accept(const kaldi::Vector& waves); + void Decode(); + std::string GetFinalResult(); + void SetFinished(); + bool IsFinished(); + void Reset(); + + private: + // std::shared_ptr resource_; + // RecognizerResource resource_; + std::shared_ptr feature_pipeline_; + std::shared_ptr decodable_; + std::unique_ptr decoder_; + bool input_finished_; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/audio/CMakeLists.txt b/speechx/speechx/frontend/audio/CMakeLists.txt index 35243b6e..2d20edf7 100644 --- a/speechx/speechx/frontend/audio/CMakeLists.txt +++ b/speechx/speechx/frontend/audio/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(frontend STATIC linear_spectrogram.cc audio_cache.cc feature_cache.cc + feature_pipeline.cc ) -target_link_libraries(frontend PUBLIC kaldi-matrix) \ No newline at end of file +target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common) diff --git a/speechx/speechx/frontend/audio/audio_cache.cc b/speechx/speechx/frontend/audio/audio_cache.cc index 50aca4fb..e8af6668 100644 --- a/speechx/speechx/frontend/audio/audio_cache.cc +++ b/speechx/speechx/frontend/audio/audio_cache.cc @@ -41,7 +41,7 @@ void AudioCache::Accept(const VectorBase& waves) { ready_feed_condition_.wait(lock); } for (size_t idx = 0; idx < waves.Dim(); ++idx) { - int32 buffer_idx = (idx + offset_) % ring_buffer_.size(); + int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size(); ring_buffer_[buffer_idx] = waves(idx); if (convert2PCM32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx)); diff --git a/speechx/speechx/frontend/audio/audio_cache.h b/speechx/speechx/frontend/audio/audio_cache.h index adef1239..a681ef09 100644 --- a/speechx/speechx/frontend/audio/audio_cache.h +++ b/speechx/speechx/frontend/audio/audio_cache.h @@ -24,7 +24,7 @@ namespace ppspeech { class AudioCache : public FrontendInterface { public: explicit AudioCache(int buffer_size = 1000 * kint16max, - bool convert2PCM32 = false); + bool convert2PCM32 = true); virtual void Accept(const kaldi::VectorBase& waves); diff --git a/speechx/speechx/frontend/audio/feature_cache.cc b/speechx/speechx/frontend/audio/feature_cache.cc index 3f7f6502..b5768460 100644 --- a/speechx/speechx/frontend/audio/feature_cache.cc +++ b/speechx/speechx/frontend/audio/feature_cache.cc @@ -23,10 +23,13 @@ using std::vector; using kaldi::SubVector; using std::unique_ptr; -FeatureCache::FeatureCache(int max_size, +FeatureCache::FeatureCache(FeatureCacheOptions opts, unique_ptr base_extractor) { - max_size_ = max_size; + max_size_ = opts.max_size; + frame_chunk_stride_ = opts.frame_chunk_stride; + frame_chunk_size_ = opts.frame_chunk_size; base_extractor_ = std::move(base_extractor); + dim_ = base_extractor_->Dim(); } void FeatureCache::Accept(const kaldi::VectorBase& inputs) { @@ -44,13 +47,14 @@ bool 
FeatureCache::Read(kaldi::Vector* feats) { std::unique_lock lock(mutex_); while (cache_.empty() && base_extractor_->IsFinished() == false) { - ready_read_condition_.wait(lock); - BaseFloat elapsed = timer.Elapsed() * 1000; - // todo replace 1.0 with timeout_ - if (elapsed > 1.0) { + // todo refactor: wait + // ready_read_condition_.wait(lock); + int32 elapsed = static_cast(timer.Elapsed() * 1000); + // todo replace 1 with timeout_, 1 ms + if (elapsed > 1) { return false; } - usleep(1000); // sleep 1 ms + usleep(100); // sleep 0.1 ms } if (cache_.empty()) return false; feats->Resize(cache_.front().Dim()); @@ -63,25 +67,41 @@ bool FeatureCache::Read(kaldi::Vector* feats) { // read all data from base_feature_extractor_ into cache_ bool FeatureCache::Compute() { // compute and feed - Vector feature_chunk; - bool result = base_extractor_->Read(&feature_chunk); + Vector feature; + bool result = base_extractor_->Read(&feature); + if (result == false || feature.Dim() == 0) return false; + int32 joint_len = feature.Dim() + remained_feature_.Dim(); + int32 num_chunk = + ((joint_len / dim_) - frame_chunk_size_) / frame_chunk_stride_ + 1; - std::unique_lock lock(mutex_); - while (cache_.size() >= max_size_) { - ready_feed_condition_.wait(lock); - } + Vector joint_feature(joint_len); + joint_feature.Range(0, remained_feature_.Dim()) + .CopyFromVec(remained_feature_); + joint_feature.Range(remained_feature_.Dim(), feature.Dim()) + .CopyFromVec(feature); - // feed cache - if (feature_chunk.Dim() != 0) { + for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { + int32 start = chunk_idx * frame_chunk_stride_ * dim_; + Vector feature_chunk(frame_chunk_size_ * dim_); + SubVector tmp(joint_feature.Data() + start, + frame_chunk_size_ * dim_); + feature_chunk.CopyFromVec(tmp); + + std::unique_lock lock(mutex_); + while (cache_.size() >= max_size_) { + ready_feed_condition_.wait(lock); + } + + // feed cache cache_.push(feature_chunk); + ready_read_condition_.notify_one(); } - ready_read_condition_.notify_one(); + int32 remained_feature_len = + joint_len - num_chunk * frame_chunk_stride_ * dim_; + remained_feature_.Resize(remained_feature_len); + remained_feature_.CopyFromVec(joint_feature.Range( + frame_chunk_stride_ * num_chunk * dim_, remained_feature_len)); return result; } -void Reset() { - // std::lock_guard lock(mutex_); - return; -} - } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/audio/feature_cache.h b/speechx/speechx/frontend/audio/feature_cache.h index 99961b5e..607f72c0 100644 --- a/speechx/speechx/frontend/audio/feature_cache.h +++ b/speechx/speechx/frontend/audio/feature_cache.h @@ -19,10 +19,18 @@ namespace ppspeech { +struct FeatureCacheOptions { + int32 max_size; + int32 frame_chunk_size; + int32 frame_chunk_stride; + FeatureCacheOptions() + : max_size(kint16max), frame_chunk_size(1), frame_chunk_stride(1) {} +}; + class FeatureCache : public FrontendInterface { public: explicit FeatureCache( - int32 max_size = kint16max, + FeatureCacheOptions opts, std::unique_ptr base_extractor = NULL); // Feed feats or waves @@ -32,12 +40,15 @@ class FeatureCache : public FrontendInterface { virtual bool Read(kaldi::Vector* feats); // feat dim - virtual size_t Dim() const { return base_extractor_->Dim(); } + virtual size_t Dim() const { return dim_; } virtual void SetFinished() { + // std::unique_lock lock(mutex_); base_extractor_->SetFinished(); + LOG(INFO) << "set finished"; // read the last chunk data Compute(); + // ready_feed_condition_.notify_one(); } 
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } @@ -52,9 +63,13 @@ class FeatureCache : public FrontendInterface { private: bool Compute(); + int32 dim_; size_t max_size_; - std::unique_ptr base_extractor_; + int32 frame_chunk_size_; + int32 frame_chunk_stride_; + kaldi::Vector remained_feature_; + std::unique_ptr base_extractor_; std::mutex mutex_; std::queue> cache_; std::condition_variable ready_feed_condition_; diff --git a/speechx/speechx/frontend/audio/feature_pipeline.cc b/speechx/speechx/frontend/audio/feature_pipeline.cc new file mode 100644 index 00000000..86eca2e0 --- /dev/null +++ b/speechx/speechx/frontend/audio/feature_pipeline.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "frontend/audio/feature_pipeline.h" + +namespace ppspeech { + +using std::unique_ptr; + +FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) { + unique_ptr data_source( + new ppspeech::AudioCache(1000 * kint16max, opts.convert2PCM32)); + + unique_ptr linear_spectrogram( + new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts, + std::move(data_source))); + + unique_ptr cmvn( + new ppspeech::CMVN(opts.cmvn_file, std::move(linear_spectrogram))); + + base_extractor_.reset( + new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); +} + +} // ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/speechx/speechx/frontend/audio/feature_pipeline.h new file mode 100644 index 00000000..7bd6c84f --- /dev/null +++ b/speechx/speechx/frontend/audio/feature_pipeline.h @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
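//
// feature_pipeline.h: a single FrontendInterface that chains
// AudioCache -> LinearSpectrogram -> CMVN -> FeatureCache, as built in
// feature_pipeline.cc above. FeaturePipelineOptions carries the cmvn file,
// the pcm32 conversion flag, the spectrogram options and the feature cache
// options; the pipeline itself just forwards Accept/Read/Dim/SetFinished/Reset
// to the outermost FeatureCache.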
+ +// todo refactor later (SGoat) + +#pragma once + +#include "frontend/audio/audio_cache.h" +#include "frontend/audio/data_cache.h" +#include "frontend/audio/feature_cache.h" +#include "frontend/audio/frontend_itf.h" +#include "frontend/audio/linear_spectrogram.h" +#include "frontend/audio/normalizer.h" + +namespace ppspeech { + +struct FeaturePipelineOptions { + std::string cmvn_file; + bool convert2PCM32; + LinearSpectrogramOptions linear_spectrogram_opts; + FeatureCacheOptions feature_cache_opts; + FeaturePipelineOptions() + : cmvn_file(""), + convert2PCM32(false), + linear_spectrogram_opts(), + feature_cache_opts() {} +}; + +class FeaturePipeline : public FrontendInterface { + public: + explicit FeaturePipeline(const FeaturePipelineOptions& opts); + virtual void Accept(const kaldi::VectorBase& waves) { + base_extractor_->Accept(waves); + } + virtual bool Read(kaldi::Vector* feats) { + return base_extractor_->Read(feats); + } + virtual size_t Dim() const { return base_extractor_->Dim(); } + virtual void SetFinished() { base_extractor_->SetFinished(); } + virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + virtual void Reset() { base_extractor_->Reset(); } + + private: + std::unique_ptr base_extractor_; +}; +} \ No newline at end of file diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.cc b/speechx/speechx/frontend/audio/linear_spectrogram.cc index d6ae3d01..9ef5e766 100644 --- a/speechx/speechx/frontend/audio/linear_spectrogram.cc +++ b/speechx/speechx/frontend/audio/linear_spectrogram.cc @@ -52,16 +52,16 @@ bool LinearSpectrogram::Read(Vector* feats) { if (flag == false || input_feats.Dim() == 0) return false; int32 feat_len = input_feats.Dim(); - int32 left_len = reminded_wav_.Dim(); + int32 left_len = remained_wav_.Dim(); Vector waves(feat_len + left_len); - waves.Range(0, left_len).CopyFromVec(reminded_wav_); + waves.Range(0, left_len).CopyFromVec(remained_wav_); waves.Range(left_len, feat_len).CopyFromVec(input_feats); Compute(waves, feats); int32 frame_shift = opts_.frame_opts.WindowShift(); int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts); int32 left_samples = waves.Dim() - frame_shift * num_frames; - reminded_wav_.Resize(left_samples); - reminded_wav_.CopyFromVec( + remained_wav_.Resize(left_samples); + remained_wav_.CopyFromVec( waves.Range(frame_shift * num_frames, left_samples)); return true; } diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.h b/speechx/speechx/frontend/audio/linear_spectrogram.h index 689ec2c4..2764b7cf 100644 --- a/speechx/speechx/frontend/audio/linear_spectrogram.h +++ b/speechx/speechx/frontend/audio/linear_spectrogram.h @@ -25,12 +25,12 @@ struct LinearSpectrogramOptions { kaldi::FrameExtractionOptions frame_opts; kaldi::BaseFloat streaming_chunk; // second - LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {} + LinearSpectrogramOptions() : streaming_chunk(0.1), frame_opts() {} void Register(kaldi::OptionsItf* opts) { opts->Register("streaming-chunk", &streaming_chunk, - "streaming chunk size, default: 0.36 sec"); + "streaming chunk size, default: 0.1 sec"); frame_opts.Register(opts); } }; @@ -48,7 +48,7 @@ class LinearSpectrogram : public FrontendInterface { virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual void Reset() { base_extractor_->Reset(); - reminded_wav_.Resize(0); + remained_wav_.Resize(0); } private: @@ -60,7 +60,7 @@ class LinearSpectrogram : public FrontendInterface { kaldi::BaseFloat hanning_window_energy_; 
LinearSpectrogramOptions opts_; std::unique_ptr base_extractor_; - kaldi::Vector reminded_wav_; + kaldi::Vector remained_wav_; int chunk_sample_size_; DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); }; diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index 3f5dadd2..465f64a9 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -78,7 +78,6 @@ bool Decodable::AdvanceChunk() { } int32 nnet_dim = 0; Vector inferences; - Matrix nnet_cache_tmp; nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); nnet_cache_.CopyRowsFromVec(inferences); From dc8efca27248d3bc5b9792ec00979ea6ce642756 Mon Sep 17 00:00:00 2001 From: Yang Zhou Date: Tue, 19 Apr 2022 20:14:02 +0800 Subject: [PATCH 22/31] add test script --- .../ds2_ol/aishell/websocket_client.sh | 37 ++++++++ .../ds2_ol/aishell/websocket_server.sh | 66 ++++++++++++++ .../ds2_ol/decoder/recognizer_test_main.cc | 85 +++++++++++++++++++ 3 files changed, 188 insertions(+) create mode 100644 speechx/examples/ds2_ol/aishell/websocket_client.sh create mode 100644 speechx/examples/ds2_ol/aishell/websocket_server.sh create mode 100644 speechx/examples/ds2_ol/decoder/recognizer_test_main.cc diff --git a/speechx/examples/ds2_ol/aishell/websocket_client.sh b/speechx/examples/ds2_ol/aishell/websocket_client.sh new file mode 100644 index 00000000..3c6b4e91 --- /dev/null +++ b/speechx/examples/ds2_ol/aishell/websocket_client.sh @@ -0,0 +1,37 @@ +#!/bin/bash +set +x +set -e + +. path.sh + +# 1. compile +if [ ! -d ${SPEECHX_EXAMPLES} ]; then + pushd ${SPEECHX_ROOT} + bash build.sh + popd +fi + +# input +mkdir -p data +data=$PWD/data +ckpt_dir=$data/model +model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ +vocb_dir=$ckpt_dir/data/lang_char +# output +aishell_wav_scp=aishell_test.scp +if [ ! -d $data/test ]; then + pushd $data + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip + unzip aishell_test.zip + popd + + realpath $data/test/*/*.wav > $data/wavlist + awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id + paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp +fi + +export GLOG_logtostderr=1 + +# websocket client +websocket_client_main \ + --wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.36 diff --git a/speechx/examples/ds2_ol/aishell/websocket_server.sh b/speechx/examples/ds2_ol/aishell/websocket_server.sh new file mode 100644 index 00000000..ea619d54 --- /dev/null +++ b/speechx/examples/ds2_ol/aishell/websocket_server.sh @@ -0,0 +1,66 @@ +#!/bin/bash +set +x +set -e + +. path.sh + + +# 1. compile +if [ ! -d ${SPEECHX_EXAMPLES} ]; then + pushd ${SPEECHX_ROOT} + bash build.sh + popd +fi + +# input +mkdir -p data +data=$PWD/data +ckpt_dir=$data/model +model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ +vocb_dir=$ckpt_dir/data/lang_char/ + +# output +aishell_wav_scp=aishell_test.scp +if [ ! -d $data/test ]; then + pushd $data + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip + unzip aishell_test.zip + popd + + realpath $data/test/*/*.wav > $data/wavlist + awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id + paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp +fi + + +if [ ! 
-d $ckpt_dir ]; then + mkdir -p $ckpt_dir + wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz + tar xzfv $ckpt_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir +fi + + +export GLOG_logtostderr=1 + +# 3. gen cmvn +cmvn=$PWD/cmvn.ark +cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn + +text=$data/test/text +graph_dir=./aishell_graph +if [ ! -d $graph_dir ]; then + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip + unzip aishell_graph.zip +fi + +# 5. test websocket server +websocket_server_main \ + --cmvn_file=$cmvn \ + --model_path=$model_dir/avg_1.jit.pdmodel \ + --streaming_chunk=0.1 \ + --convert2PCM32=true \ + --params_path=$model_dir/avg_1.jit.pdiparams \ + --word_symbol_table=$graph_dir/words.txt \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --graph_path=$graph_dir/TLG.fst --max_active=7500 \ + --acoustic_scale=1.2 diff --git a/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc b/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc new file mode 100644 index 00000000..198a8ec2 --- /dev/null +++ b/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
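//
// recognizer_test_main: offline driver that simulates streaming recognition.
// It reads a wav scp, cuts each waveform into chunks of
// streaming_chunk * 16000 samples, feeds them one by one to
// Recognizer::Accept() followed by Decode(), calls SetFinished() on the last
// (short) chunk, then fetches the final transcript, resets the recognizer and
// writes non-empty results with a kaldi TokenWriter.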
+ +#include "decoder/recognizer.h" +#include "decoder/param.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/table-types.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure(); + ppspeech::Recognizer recognizer(resource); + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + int sample_rate = 16000; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + int32 num_done = 0, num_err = 0; + + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + std::vector> feats; + int feature_rows = 0; + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + kaldi::Vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk(i) = waveform(sample_offset + i); + } + + recognizer.Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + recognizer.SetFinished(); + } + recognizer.Decode(); + + sample_offset += cur_chunk_size; + } + std::string result; + result = recognizer.GetFinalResult(); + recognizer.Reset(); + if (result.empty()) { + // the TokenWriter can not write empty string. 
+ ++num_err; + KALDI_LOG << " the result of " << utt << " is empty"; + continue; + } + KALDI_LOG << " the result of " << utt << " is " << result; + result_writer.Write(utt, result); + ++num_done; + } +} \ No newline at end of file From 00febff734bdc538fc8463c2ecf22b9c136b53f4 Mon Sep 17 00:00:00 2001 From: qingen Date: Tue, 19 Apr 2022 20:33:43 +0800 Subject: [PATCH 23/31] [vec][loss] update docstring, test=doc fix #1717 --- paddlespeech/vector/modules/loss.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddlespeech/vector/modules/loss.py b/paddlespeech/vector/modules/loss.py index 015c0dfe..af38dd01 100644 --- a/paddlespeech/vector/modules/loss.py +++ b/paddlespeech/vector/modules/loss.py @@ -132,6 +132,9 @@ class NCELoss(nn.Layer): def forward(self, output, target): """Forward inference + + Args: + output (tensor): the model output, which is the input of loss function """ output = paddle.reshape(output, [-1, self.N]) B = output.shape[0] From 48fa84bee90d8fc8b9f5619f8e22e796b8a10aca Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 20:15:18 +0800 Subject: [PATCH 24/31] fix the asr online client bug, return None, test=doc --- paddlespeech/s2t/modules/encoder.py | 2 -- paddlespeech/server/README.md | 13 +++++++++++++ paddlespeech/server/README_cn.md | 14 ++++++++++++++ paddlespeech/server/bin/paddlespeech_client.py | 6 ++++-- .../server/engine/asr/online/asr_engine.py | 4 ++-- .../server/engine/asr/online/ctc_search.py | 8 +++----- .../server/tests/asr/online/websocket_client.py | 11 ++++------- 7 files changed, 40 insertions(+), 18 deletions(-) diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index 347035cd..c843c0e2 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -317,8 +317,6 @@ class BaseEncoder(nn.Layer): outputs = [] offset = 0 # Feed forward overlap input step by step - print(f"context: {context}") - print(f"stride: {stride}") for cur in range(0, num_frames - context + 1, stride): end = min(cur + decoding_window, num_frames) chunk_xs = xs[:, cur:end, :] diff --git a/paddlespeech/server/README.md b/paddlespeech/server/README.md index 819fe440..3ac68dae 100644 --- a/paddlespeech/server/README.md +++ b/paddlespeech/server/README.md @@ -35,3 +35,16 @@ ```bash paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav ``` + + ## Online ASR Server + +### Lanuch online asr server +``` +paddlespeech_server start --config_file conf/ws_conformer_application.yaml +``` + +### Access online asr server + +``` +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav +``` \ No newline at end of file diff --git a/paddlespeech/server/README_cn.md b/paddlespeech/server/README_cn.md index c0a4a733..5f235313 100644 --- a/paddlespeech/server/README_cn.md +++ b/paddlespeech/server/README_cn.md @@ -35,3 +35,17 @@ ```bash paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav ``` + +## 流式ASR + +### 启动流式语音识别服务 + +``` +paddlespeech_server start --config_file conf/ws_conformer_application.yaml +``` + +### 访问流式语音识别服务 + +``` +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav +``` \ No newline at end of file diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index cb802ce5..45469178 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor): 
lang=lang, audio_format=audio_format) time_end = time.time() - logger.info(res.json()) + logger.info(res) logger.info("Response time %f s." % (time_end - time_start)) return True except Exception as e: logger.error("Failed to speech recognition.") + logger.error(e) return False @stats_wrapper @@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor): logging.info("asr websocket client start") handler = ASRAudioHandler(server_ip, port) loop = asyncio.get_event_loop() - loop.run_until_complete(handler.run(input)) + res = loop.run_until_complete(handler.run(input)) logging.info("asr websocket client finished") + return res['asr_results'] @cli_client_register( name='paddlespeech_client.cls', description='visit cls service') diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 4d15d93b..c79abf1b 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -473,7 +473,7 @@ class PaddleASRConnectionHanddler: ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) - self.searcher.search(None, ctc_probs, self.cached_feat.place) + self.searcher.search(ctc_probs, self.cached_feat.place) self.hyps = self.searcher.get_one_best_hyps() assert self.cached_feat.shape[0] == 1 @@ -823,7 +823,7 @@ class ASRServerExecutor(ASRExecutor): ctc_probs = self.model.ctc.log_softmax( encoder_out) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) - self.searcher.search(xs, ctc_probs, xs.place) + self.searcher.search(ctc_probs, xs.place) # update the one best result self.hyps = self.searcher.get_one_best_hyps() diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index c3822b5c..b1c80c36 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -24,19 +24,18 @@ class CTCPrefixBeamSearch: """Implement the ctc prefix beam search Args: - config (_type_): _description_ + config (yacs.config.CfgNode): _description_ """ self.config = config self.reset() - def search(self, xs, ctc_probs, device, blank_id=0): + def search(self, ctc_probs, device, blank_id=0): """ctc prefix beam search method decode a chunk feature Args: xs (paddle.Tensor): feature data ctc_probs (paddle.Tensor): the ctc probability of all the tokens - encoder_out (paddle.Tensor): _description_ - encoder_mask (_type_): _description_ + device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0). blank_id (int, optional): the blank id in the vocab. Defaults to 0. 
Returns: @@ -45,7 +44,6 @@ class CTCPrefixBeamSearch: # decode logger.info("start to ctc prefix search") - # device = xs.place batch_size = 1 beam_size = self.config.beam_size maxlen = ctc_probs.shape[0] diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py index 62e011ce..49cbd703 100644 --- a/paddlespeech/server/tests/asr/online/websocket_client.py +++ b/paddlespeech/server/tests/asr/online/websocket_client.py @@ -34,10 +34,9 @@ class ASRAudioHandler: def read_wave(self, wavfile_path: str): samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') x_len = len(samples) - # chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz - chunk_size = 80 * 16 #80ms, sample_rate = 16kHz - if x_len % chunk_size != 0: + chunk_size = 85 * 16 #80ms, sample_rate = 16kHz + if x_len % chunk_size!= 0: padding_len_x = chunk_size - x_len % chunk_size else: padding_len_x = 0 @@ -48,7 +47,6 @@ class ASRAudioHandler: assert (x_len + padding_len_x) % chunk_size == 0 num_chunk = (x_len + padding_len_x) / chunk_size num_chunk = int(num_chunk) - for i in range(0, num_chunk): start = i * chunk_size end = start + chunk_size @@ -82,7 +80,6 @@ class ASRAudioHandler: msg = json.loads(msg) logging.info("receive msg={}".format(msg)) - result = msg # finished audio_info = json.dumps( { @@ -98,8 +95,8 @@ class ASRAudioHandler: # decode the bytes to str msg = json.loads(msg) - logging.info("receive msg={}".format(msg)) - + logging.info("final receive msg={}".format(msg)) + result = msg return result From 9c03280ca699dbf9837cdedbc0d93d2c11cc9412 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 21:01:13 +0800 Subject: [PATCH 25/31] remove debug info, test=doc --- paddlespeech/s2t/models/u2/u2.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index f0d2711d..9b66126e 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -213,14 +213,12 @@ class U2BaseModel(ASRInterface, nn.Layer): num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) else: - print("offline decode from the asr") encoder_out, encoder_mask = self.encoder( speech, speech_lengths, decoding_chunk_size=decoding_chunk_size, num_decoding_left_chunks=num_decoding_left_chunks ) # (B, maxlen, encoder_dim) - print("offline decode success") return encoder_out, encoder_mask def recognize( @@ -281,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer): # TODO(Hui Zhang): if end_flag.sum() == running_size: if end_flag.cast(paddle.int64).sum() == running_size: break - + # 2.1 Forward decoder step hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( running_size, 1, 1).to(device) # (B*N, i, i) # logp: (B*N, vocab) logp, cache = self.decoder.forward_one_step( encoder_out, encoder_mask, hyps, hyps_mask, cache) - # 2.2 First beam prune: select topk best prob at current time top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) top_k_logp = mask_finished_scores(top_k_logp, end_flag) @@ -708,7 +705,6 @@ class U2BaseModel(ASRInterface, nn.Layer): List[List[int]]: transcripts. 
""" batch_size = feats.shape[0] - print("start to decode the audio feat") if decoding_method in ['ctc_prefix_beam_search', 'attention_rescoring'] and batch_size > 1: logger.error( @@ -716,7 +712,6 @@ class U2BaseModel(ASRInterface, nn.Layer): ) logger.error(f"current batch_size is {batch_size}") sys.exit(1) - print(f"use the {decoding_method} to decode the audio feat") if decoding_method == 'attention': hyps = self.recognize( feats, From a7105ca3e7e177b80bc23967e25b7f80ba046137 Mon Sep 17 00:00:00 2001 From: Yang Zhou Date: Tue, 19 Apr 2022 21:03:24 +0800 Subject: [PATCH 26/31] fix ctc prefix beam binary --- .../ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc b/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc index 21afec27..e145f6ee 100644 --- a/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc +++ b/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc @@ -50,6 +50,9 @@ int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); + CHECK(FLAGS_result_wspecifier != ""); + CHECK(FLAGS_feature_rspecifier != ""); + kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_rspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); From ff4ddd229e8798f31fce71f7e096319d6171ed3f Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Tue, 19 Apr 2022 23:12:46 +0800 Subject: [PATCH 27/31] fix the unuseful code, test=doc --- paddlespeech/s2t/modules/ctc.py | 1 - paddlespeech/server/conf/ws_application.yaml | 8 -------- paddlespeech/server/conf/ws_conformer_application.yaml | 2 +- paddlespeech/server/engine/asr/online/asr_engine.py | 2 +- paddlespeech/server/ws/asr_socket.py | 1 - 5 files changed, 2 insertions(+), 12 deletions(-) diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index bd1219b1..1bb15873 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -181,7 +181,6 @@ class CTCDecoder(CTCDecoderBase): if self._ext_scorer is not None: return - from paddlespeech.s2t.decoders.ctcdecoder import Scorer # noqa: F401 if language_model_path != '': logger.info("begin to initialize the external scorer " "for decoding") diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index dae4a3ff..dee8d78b 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -45,11 +45,3 @@ asr_online: shift_n: 4 # frame window_ms: 20 # ms shift_ms: 10 # ms - - vad_conf: - aggressiveness: 2 - sample_rate: 16000 - frame_duration_ms: 20 - sample_width: 2 - padding_ms: 200 - padding_ratio: 0.9 \ No newline at end of file diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml index 89a861ef..e14833de 100644 --- a/paddlespeech/server/conf/ws_conformer_application.yaml +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -21,7 +21,7 @@ engine_list: ['asr_online'] ################################### ASR ######################################### ################### speech task: asr; engine_type: online ####################### asr_online: - model_type: 'conformer_online_multi-cn' + model_type: 'conformer_online_multicn' am_model: # the pdmodel file of am static model [optional] am_params: # the pdiparams file of am static model 
[optional] lang: 'zh' diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index c79abf1b..34a028a3 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -59,7 +59,7 @@ pretrained_models = { 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' }, - "conformer_online_multi-cn-zh-16k": { + "conformer_online_multicn-zh-16k": { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', 'md5': diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 82b05bc5..a865703d 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -28,7 +28,6 @@ router = APIRouter() @router.websocket('/ws/asr') async def websocket_endpoint(websocket: WebSocket): - print("websocket protocal receive the dataset") await websocket.accept() engine_pool = get_engine_pool() From ac9fcf7f4a53026bba8efe235d90a0693a70eae6 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Wed, 20 Apr 2022 00:15:37 +0800 Subject: [PATCH 28/31] fix the asr infernece model, paddle.no_grad, test=doc --- paddlespeech/server/engine/asr/online/asr_engine.py | 3 +++ paddlespeech/server/engine/asr/online/ctc_search.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index 34a028a3..758cbaab 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -356,6 +356,7 @@ class PaddleASRConnectionHanddler: else: raise Exception("invalid model name") + @paddle.no_grad() def decode_one_chunk(self, x_chunk, x_chunk_lens): logger.info("start to decoce one chunk with deepspeech2 model") input_names = self.am_predictor.get_input_names() @@ -397,6 +398,7 @@ class PaddleASRConnectionHanddler: logger.info(f"decode one best result: {trans_best[0]}") return trans_best[0] + @paddle.no_grad() def advance_decoding(self, is_finished=False): logger.info("start to decode with advanced_decoding method") cfg = self.ctc_decode_config @@ -503,6 +505,7 @@ class PaddleASRConnectionHanddler: else: return '' + @paddle.no_grad() def rescoring(self): if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: return diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index b1c80c36..8aee0a50 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 from collections import defaultdict
-
+import paddle
 from paddlespeech.cli.log import logger
 from paddlespeech.s2t.utils.utility import log_add
@@ -29,6 +29,7 @@ class CTCPrefixBeamSearch:
         self.config = config
         self.reset()

+    @paddle.no_grad()
     def search(self, ctc_probs, device, blank_id=0):
         """ctc prefix beam search method decode a chunk feature

From 6a7245657fa088db7ee1c275ae953695e8d77d39 Mon Sep 17 00:00:00 2001
From: qingen
Date: Wed, 20 Apr 2022 11:33:25 +0800
Subject: [PATCH 29/31] [vec][loss] add FocalLoss to deal with class imbalances, test=doc

fix #1721
---
 paddlespeech/vector/modules/loss.py       | 66 ++++++++++++++++++++++-
 paddlespeech/vector/utils/vector_utils.py |  1 +
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/paddlespeech/vector/modules/loss.py b/paddlespeech/vector/modules/loss.py
index af38dd01..9a7530c1 100644
--- a/paddlespeech/vector/modules/loss.py
+++ b/paddlespeech/vector/modules/loss.py
@@ -132,7 +132,7 @@ class NCELoss(nn.Layer):

     def forward(self, output, target):
         """Forward inference
-
+
         Args:
             output (tensor): the model output, which is the input of loss function
         """
@@ -161,7 +161,7 @@ class NCELoss(nn.Layer):
         """Post processing the score of post model(output of nn) of batchsize data
         """
         scores = self.get_scores(idx, scores)
-        scale = paddle.to_tensor([self.Z_offset], dtype='float32')
+        scale = paddle.to_tensor([self.Z_offset], dtype='float64')
         scores = paddle.add(scores, -scale)
         prob = paddle.exp(scores)
         if sep_target:
@@ -225,3 +225,65 @@ class NCELoss(nn.Layer):

         loss = -(model_loss + noise_loss)
         return loss
+
+
+class FocalLoss(nn.Layer):
+    """This criterion is an implementation of Focal Loss, which is proposed in
+    Focal Loss for Dense Object Detection.
+
+        Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
+
+    The losses are averaged across observations for each minibatch.
+
+    Args:
+        alpha(1D Tensor, Variable) : the scalar factor for this criterion
+        gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
+                               putting more focus on hard, misclassified examples
+        size_average(bool): By default, the losses are averaged over observations for each minibatch.
+                            However, if the field size_average is set to False, the losses are
+                            instead summed for each minibatch.
+    """
+
+    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=-100):
+        super(FocalLoss, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.size_average = size_average
+        self.ce = nn.CrossEntropyLoss(
+            ignore_index=ignore_index, reduction="none")
+
+    def forward(self, outputs, targets):
+        """Forward inference.
+ + Args: + outputs: input tensor + target: target label tensor + """ + ce_loss = self.ce(outputs, targets) + pt = paddle.exp(-ce_loss) + focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss + if self.size_average: + return focal_loss.mean() + else: + return focal_loss.sum() + + +if __name__ == "__main__": + import numpy as np + from paddlespeech.vector.utils.vector_utils import Q_from_tokens + paddle.set_device("cpu") + + input_data = paddle.uniform([5, 100], dtype="float64") + label_data = np.random.randint(0, 100, size=(5)).astype(np.int64) + + input = paddle.to_tensor(input_data) + label = paddle.to_tensor(label_data) + + loss1 = FocalLoss() + loss = loss1.forward(input, label) + print("loss: %.5f" % (loss)) + + Q = Q_from_tokens(100) + loss2 = NCELoss(Q) + loss = loss2.forward(input, label) + print("loss: %.5f" % (loss)) diff --git a/paddlespeech/vector/utils/vector_utils.py b/paddlespeech/vector/utils/vector_utils.py index dcf0f1aa..d6659e3f 100644 --- a/paddlespeech/vector/utils/vector_utils.py +++ b/paddlespeech/vector/utils/vector_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import paddle def get_chunks(seg_dur, audio_id, audio_duration): From c74fa9ada8451d23463fb743dc53efffec82ab51 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Wed, 20 Apr 2022 06:22:46 +0000 Subject: [PATCH 30/31] restructure syn_utils.py, test=tts --- examples/csmsc/tts0/local/inference.sh | 14 +- examples/csmsc/tts3/local/inference.sh | 1 - examples/csmsc/tts3/local/synthesize_e2e.sh | 4 +- examples/ljspeech/tts3/local/synthesize.sh | 2 +- paddlespeech/t2s/exps/inference.py | 31 ++- paddlespeech/t2s/exps/inference_streaming.py | 38 ++- paddlespeech/t2s/exps/ort_predict.py | 14 +- paddlespeech/t2s/exps/ort_predict_e2e.py | 20 +- .../t2s/exps/ort_predict_streaming.py | 34 ++- paddlespeech/t2s/exps/syn_utils.py | 261 ++++++++---------- paddlespeech/t2s/exps/synthesize.py | 24 +- paddlespeech/t2s/exps/synthesize_e2e.py | 36 ++- paddlespeech/t2s/exps/synthesize_streaming.py | 19 +- paddlespeech/t2s/exps/voice_cloning.py | 13 +- paddlespeech/t2s/exps/wavernn/synthesize.py | 3 +- 15 files changed, 300 insertions(+), 214 deletions(-) diff --git a/examples/csmsc/tts0/local/inference.sh b/examples/csmsc/tts0/local/inference.sh index e417d748..d2960441 100755 --- a/examples/csmsc/tts0/local/inference.sh +++ b/examples/csmsc/tts0/local/inference.sh @@ -27,20 +27,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --phones_dict=dump/phone_id_map.txt fi -# style melgan -# style melgan's Dygraph to Static Graph is not ready now -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - python3 ${BIN_DIR}/../inference.py \ - --inference_dir=${train_output_path}/inference \ - --am=tacotron2_csmsc \ - --voc=style_melgan_csmsc \ - --text=${BIN_DIR}/../sentences.txt \ - --output_dir=${train_output_path}/pd_infer_out \ - --phones_dict=dump/phone_id_map.txt -fi - # hifigan -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then python3 ${BIN_DIR}/../inference.py \ --inference_dir=${train_output_path}/inference \ --am=tacotron2_csmsc \ diff --git a/examples/csmsc/tts3/local/inference.sh b/examples/csmsc/tts3/local/inference.sh index 7052b347..b43fd286 100755 --- a/examples/csmsc/tts3/local/inference.sh +++ b/examples/csmsc/tts3/local/inference.sh @@ -28,7 +28,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 
1 ]; then --phones_dict=dump/phone_id_map.txt fi - # hifigan if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then python3 ${BIN_DIR}/../inference.py \ diff --git a/examples/csmsc/tts3/local/synthesize_e2e.sh b/examples/csmsc/tts3/local/synthesize_e2e.sh index 512e062b..8130eff1 100755 --- a/examples/csmsc/tts3/local/synthesize_e2e.sh +++ b/examples/csmsc/tts3/local/synthesize_e2e.sh @@ -109,6 +109,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --lang=zh \ --text=${BIN_DIR}/../sentences.txt \ --output_dir=${train_output_path}/test_e2e \ - --phones_dict=dump/phone_id_map.txt \ - --inference_dir=${train_output_path}/inference + --phones_dict=dump/phone_id_map.txt #\ + # --inference_dir=${train_output_path}/inference fi diff --git a/examples/ljspeech/tts3/local/synthesize.sh b/examples/ljspeech/tts3/local/synthesize.sh index 6dc34274..0733e96f 100755 --- a/examples/ljspeech/tts3/local/synthesize.sh +++ b/examples/ljspeech/tts3/local/synthesize.sh @@ -26,7 +26,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then fi # hifigan -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then FLAGS_allocator_strategy=naive_best_fit \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \ python3 ${BIN_DIR}/../synthesize.py \ diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 3e7c11f2..7a19a113 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -102,20 +102,31 @@ def parse_args(): def main(): args = parse_args() # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) # am_predictor - am_predictor = get_predictor(args, filed='am') + am_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.am + ".pdmodel", + params_file=args.am + ".pdiparams", + device=args.device) # model: {model_name}_{dataset} am_dataset = args.am[args.am.rindex('_') + 1:] # voc_predictor - voc_predictor = get_predictor(args, filed='voc') + voc_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.voc + ".pdmodel", + params_file=args.voc + ".pdiparams", + device=args.device) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) merge_sentences = True fs = 24000 if am_dataset != 'ljspeech' else 22050 @@ -123,11 +134,13 @@ def main(): for utt_id, sentence in sentences[:3]: with timer() as t: am_output_data = get_am_output( - args, + input=sentence, am_predictor=am_predictor, + am=args.am, frontend=frontend, + lang=args.lang, merge_sentences=merge_sentences, - input=sentence) + speaker_dict=args.speaker_dict, ) wav = get_voc_output( voc_predictor=voc_predictor, input=am_output_data) speed = wav.size / t.elapse @@ -143,11 +156,13 @@ def main(): for utt_id, sentence in sentences: with timer() as t: am_output_data = get_am_output( - args, + input=sentence, am_predictor=am_predictor, + am=args.am, frontend=frontend, + lang=args.lang, merge_sentences=merge_sentences, - input=sentence) + speaker_dict=args.speaker_dict, ) wav = get_voc_output( voc_predictor=voc_predictor, input=am_output_data) diff --git a/paddlespeech/t2s/exps/inference_streaming.py b/paddlespeech/t2s/exps/inference_streaming.py index 0e58056c..ef6d1a4a 100644 --- a/paddlespeech/t2s/exps/inference_streaming.py +++ b/paddlespeech/t2s/exps/inference_streaming.py @@ -25,7 +25,6 @@ 
from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_predictor from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output -from paddlespeech.t2s.exps.syn_utils import get_streaming_am_predictor from paddlespeech.t2s.exps.syn_utils import get_voc_output from paddlespeech.t2s.utils import str2bool @@ -102,22 +101,43 @@ def parse_args(): def main(): args = parse_args() # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) # am_predictor - am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor = get_streaming_am_predictor( - args) + + am_encoder_infer_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.am + "_am_encoder_infer" + ".pdmodel", + params_file=args.am + "_am_encoder_infer" + ".pdiparams", + device=args.device) + am_decoder_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.am + "_am_decoder" + ".pdmodel", + params_file=args.am + "_am_decoder" + ".pdiparams", + device=args.device) + am_postnet_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.am + "_am_postnet" + ".pdmodel", + params_file=args.am + "_am_postnet" + ".pdiparams", + device=args.device) am_mu, am_std = np.load(args.am_stat) # model: {model_name}_{dataset} am_dataset = args.am[args.am.rindex('_') + 1:] # voc_predictor - voc_predictor = get_predictor(args, filed='voc') + voc_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.voc + ".pdmodel", + params_file=args.voc + ".pdiparams", + device=args.device) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) merge_sentences = True @@ -126,13 +146,13 @@ def main(): for utt_id, sentence in sentences[:3]: with timer() as t: normalized_mel = get_streaming_am_output( - args, + input=sentence, am_encoder_infer_predictor=am_encoder_infer_predictor, am_decoder_predictor=am_decoder_predictor, am_postnet_predictor=am_postnet_predictor, frontend=frontend, - merge_sentences=merge_sentences, - input=sentence) + lang=args.lang, + merge_sentences=merge_sentences, ) mel = denorm(normalized_mel, am_mu, am_std) wav = get_voc_output(voc_predictor=voc_predictor, input=mel) speed = wav.size / t.elapse diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py index d1f03710..adbd6809 100644 --- a/paddlespeech/t2s/exps/ort_predict.py +++ b/paddlespeech/t2s/exps/ort_predict.py @@ -30,7 +30,7 @@ def ort_predict(args): test_metadata = list(reader) am_name = args.am[:args.am.rindex('_')] am_dataset = args.am[args.am.rindex('_') + 1:] - test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset) + test_dataset = get_test_dataset(test_metadata=test_metadata, am=args.am) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -38,10 +38,18 @@ def ort_predict(args): fs = 24000 if am_dataset != 'ljspeech' else 22050 # am - am_sess = get_sess(args, filed='am') + am_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.am + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) # vocoder - voc_sess = get_sess(args, filed='voc') + voc_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.voc + ".onnx", + device=args.device, + 
cpu_threads=args.cpu_threads) # am warmup for T in [27, 38, 54]: diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py index 366a2902..ae5e900b 100644 --- a/paddlespeech/t2s/exps/ort_predict_e2e.py +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -27,21 +27,31 @@ from paddlespeech.t2s.utils import str2bool def ort_predict(args): # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) am_name = args.am[:args.am.rindex('_')] am_dataset = args.am[args.am.rindex('_') + 1:] fs = 24000 if am_dataset != 'ljspeech' else 22050 - # am - am_sess = get_sess(args, filed='am') + am_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.am + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) # vocoder - voc_sess = get_sess(args, filed='voc') + voc_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.voc + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) # frontend warmup # Loading model cost 0.5+ seconds diff --git a/paddlespeech/t2s/exps/ort_predict_streaming.py b/paddlespeech/t2s/exps/ort_predict_streaming.py index 1b486d19..5568ed39 100644 --- a/paddlespeech/t2s/exps/ort_predict_streaming.py +++ b/paddlespeech/t2s/exps/ort_predict_streaming.py @@ -23,30 +23,50 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sess -from paddlespeech.t2s.exps.syn_utils import get_streaming_am_sess from paddlespeech.t2s.utils import str2bool def ort_predict(args): # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) am_name = args.am[:args.am.rindex('_')] am_dataset = args.am[args.am.rindex('_') + 1:] fs = 24000 if am_dataset != 'ljspeech' else 22050 - # am - am_encoder_infer_sess, am_decoder_sess, am_postnet_sess = get_streaming_am_sess( - args) + # streaming acoustic model + am_encoder_infer_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.am + "_am_encoder_infer" + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) + am_decoder_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.am + "_am_decoder" + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) + + am_postnet_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.am + "_am_postnet" + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) am_mu, am_std = np.load(args.am_stat) # vocoder - voc_sess = get_sess(args, filed='voc') + voc_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.voc + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) # frontend warmup # Loading model cost 0.5+ seconds diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index 21aa5bf8..ce0aee05 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -14,6 +14,10 @@ import math import os from pathlib import Path +from 
typing import Any +from typing import Dict +from typing import List +from typing import Optional import numpy as np import onnxruntime as ort @@ -21,6 +25,7 @@ import paddle from paddle import inference from paddle import jit from paddle.static import InputSpec +from yacs.config import CfgNode from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.t2s.datasets.data_table import DataTable @@ -70,7 +75,7 @@ def denorm(data, mean, std): return data * std + mean -def get_chunks(data, chunk_size, pad_size): +def get_chunks(data, chunk_size: int, pad_size: int): data_len = data.shape[1] chunks = [] n = math.ceil(data_len / chunk_size) @@ -82,28 +87,34 @@ def get_chunks(data, chunk_size, pad_size): # input -def get_sentences(args): +def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'): # construct dataset for evaluation sentences = [] - with open(args.text, 'rt') as f: + with open(text_file, 'rt') as f: for line in f: items = line.strip().split() utt_id = items[0] - if 'lang' in args and args.lang == 'zh': + if lang == 'zh': sentence = "".join(items[1:]) - elif 'lang' in args and args.lang == 'en': + elif lang == 'en': sentence = " ".join(items[1:]) sentences.append((utt_id, sentence)) return sentences -def get_test_dataset(args, test_metadata, am_name, am_dataset): +def get_test_dataset(test_metadata: List[Dict[str, Any]], + am: str, + speaker_dict: Optional[os.PathLike]=None, + voice_cloning: bool=False): + # model: {model_name}_{dataset} + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] if am_name == 'fastspeech2': fields = ["utt_id", "text"] - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None: print("multiple speaker fastspeech2!") fields += ["spk_id"] - elif 'voice_cloning' in args and args.voice_cloning: + elif voice_cloning: print("voice cloning!") fields += ["spk_emb"] else: @@ -112,7 +123,7 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset): fields = ["utt_id", "phones", "tones"] elif am_name == 'tacotron2': fields = ["utt_id", "text"] - if 'voice_cloning' in args and args.voice_cloning: + if voice_cloning: print("voice cloning!") fields += ["spk_emb"] @@ -121,12 +132,14 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset): # frontend -def get_frontend(args): - if 'lang' in args and args.lang == 'zh': +def get_frontend(lang: str='zh', + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None): + if lang == 'zh': frontend = Frontend( - phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict) - elif 'lang' in args and args.lang == 'en': - frontend = English(phone_vocab_path=args.phones_dict) + phone_vocab_path=phones_dict, tone_vocab_path=tones_dict) + elif lang == 'en': + frontend = English(phone_vocab_path=phones_dict) else: print("wrong lang!") print("frontend done!") @@ -134,30 +147,37 @@ def get_frontend(args): # dygraph -def get_am_inference(args, am_config): - with open(args.phones_dict, "r") as f: +def get_am_inference( + am: str='fastspeech2_csmsc', + am_config: CfgNode=None, + am_ckpt: Optional[os.PathLike]=None, + am_stat: Optional[os.PathLike]=None, + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None, + speaker_dict: Optional[os.PathLike]=None, ): + with open(phones_dict, "r") as f: phn_id = [line.strip().split() for line in f.readlines()] vocab_size = len(phn_id) print("vocab_size:", vocab_size) tone_size = None - if 'tones_dict' in args 
and args.tones_dict: - with open(args.tones_dict, "r") as f: + if tones_dict is not None: + with open(tones_dict, "r") as f: tone_id = [line.strip().split() for line in f.readlines()] tone_size = len(tone_id) print("tone_size:", tone_size) spk_num = None - if 'speaker_dict' in args and args.speaker_dict: - with open(args.speaker_dict, 'rt') as f: + if speaker_dict is not None: + with open(speaker_dict, 'rt') as f: spk_id = [line.strip().split() for line in f.readlines()] spk_num = len(spk_id) print("spk_num:", spk_num) odim = am_config.n_mels # model: {model_name}_{dataset} - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] am_class = dynamic_import(am_name, model_alias) am_inference_class = dynamic_import(am_name + '_inference', model_alias) @@ -174,34 +194,38 @@ def get_am_inference(args, am_config): elif am_name == 'tacotron2': am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) - am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) + am.set_state_dict(paddle.load(am_ckpt)["main_params"]) am.eval() - am_mu, am_std = np.load(args.am_stat) + am_mu, am_std = np.load(am_stat) am_mu = paddle.to_tensor(am_mu) am_std = paddle.to_tensor(am_std) am_normalizer = ZScore(am_mu, am_std) am_inference = am_inference_class(am_normalizer, am) am_inference.eval() print("acoustic model done!") - return am_inference, am_name, am_dataset + return am_inference -def get_voc_inference(args, voc_config): +def get_voc_inference( + voc: str='pwgan_csmsc', + voc_config: Optional[os.PathLike]=None, + voc_ckpt: Optional[os.PathLike]=None, + voc_stat: Optional[os.PathLike]=None, ): # model: {model_name}_{dataset} - voc_name = args.voc[:args.voc.rindex('_')] + voc_name = voc[:voc.rindex('_')] voc_class = dynamic_import(voc_name, model_alias) voc_inference_class = dynamic_import(voc_name + '_inference', model_alias) if voc_name != 'wavernn': voc = voc_class(**voc_config["generator_params"]) - voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"]) + voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"]) voc.remove_weight_norm() voc.eval() else: voc = voc_class(**voc_config["model"]) - voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"]) + voc.set_state_dict(paddle.load(voc_ckpt)["main_params"]) voc.eval() - voc_mu, voc_std = np.load(args.voc_stat) + voc_mu, voc_std = np.load(voc_stat) voc_mu = paddle.to_tensor(voc_mu) voc_std = paddle.to_tensor(voc_std) voc_normalizer = ZScore(voc_mu, voc_std) @@ -211,10 +235,16 @@ def get_voc_inference(args, voc_config): return voc_inference -# to static -def am_to_static(args, am_inference, am_name, am_dataset): +# dygraph to static graph +def am_to_static(am_inference, + am: str='fastspeech2_csmsc', + inference_dir=Optional[os.PathLike], + speaker_dict: Optional[os.PathLike]=None): + # model: {model_name}_{dataset} + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] if am_name == 'fastspeech2': - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None: am_inference = jit.to_static( am_inference, input_spec=[ @@ -226,7 +256,7 @@ def am_to_static(args, am_inference, am_name, am_dataset): am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) elif am_name == 'speedyspeech': - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None: am_inference = 
jit.to_static( am_inference, input_spec=[ @@ -247,56 +277,64 @@ def am_to_static(args, am_inference, am_name, am_dataset): am_inference = jit.to_static( am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) - paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am)) - am_inference = paddle.jit.load(os.path.join(args.inference_dir, args.am)) + paddle.jit.save(am_inference, os.path.join(inference_dir, am)) + am_inference = paddle.jit.load(os.path.join(inference_dir, am)) return am_inference -def voc_to_static(args, voc_inference): +def voc_to_static(voc_inference, + voc: str='pwgan_csmsc', + inference_dir=Optional[os.PathLike]): voc_inference = jit.to_static( voc_inference, input_spec=[ InputSpec([-1, 80], dtype=paddle.float32), ]) - paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc)) - voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc)) + paddle.jit.save(voc_inference, os.path.join(inference_dir, voc)) + voc_inference = paddle.jit.load(os.path.join(inference_dir, voc)) return voc_inference # inference -def get_predictor(args, filed='am'): - full_name = '' - if filed == 'am': - full_name = args.am - elif filed == 'voc': - full_name = args.voc +def get_predictor(model_dir: Optional[os.PathLike]=None, + model_file: Optional[os.PathLike]=None, + params_file: Optional[os.PathLike]=None, + device: str='cpu'): + config = inference.Config( - str(Path(args.inference_dir) / (full_name + ".pdmodel")), - str(Path(args.inference_dir) / (full_name + ".pdiparams"))) - if args.device == "gpu": + str(Path(model_dir) / model_file), str(Path(model_dir) / params_file)) + if device == "gpu": config.enable_use_gpu(100, 0) - elif args.device == "cpu": + elif device == "cpu": config.disable_gpu() config.enable_memory_optim() predictor = inference.create_predictor(config) return predictor -def get_am_output(args, am_predictor, frontend, merge_sentences, input): - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] +def get_am_output( + input: str, + am_predictor, + am, + frontend, + lang: str='zh', + merge_sentences: bool=True, + speaker_dict: Optional[os.PathLike]=None, + spk_id: int=0, ): + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] am_input_names = am_predictor.get_input_names() get_tone_ids = False get_spk_id = False if am_name == 'speedyspeech': get_tone_ids = True - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + if am_dataset in {"aishell3", "vctk"} and speaker_dict: get_spk_id = True - spk_id = np.array([args.spk_id]) - if args.lang == 'zh': + spk_id = np.array([spk_id]) + if lang == 'zh': input_ids = frontend.get_input_ids( input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] - elif args.lang == 'en': + elif lang == 'en': input_ids = frontend.get_input_ids( input, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] @@ -338,50 +376,6 @@ def get_voc_output(voc_predictor, input): return wav -# streaming am -def get_streaming_am_predictor(args): - full_name = args.am - am_encoder_infer_config = inference.Config( - str( - Path(args.inference_dir) / - (full_name + "_am_encoder_infer" + ".pdmodel")), - str( - Path(args.inference_dir) / - (full_name + "_am_encoder_infer" + ".pdiparams"))) - am_decoder_config = inference.Config( - str( - Path(args.inference_dir) / - (full_name + "_am_decoder" + ".pdmodel")), - str( - Path(args.inference_dir) / - (full_name + "_am_decoder" + ".pdiparams"))) - 
am_postnet_config = inference.Config( - str( - Path(args.inference_dir) / - (full_name + "_am_postnet" + ".pdmodel")), - str( - Path(args.inference_dir) / - (full_name + "_am_postnet" + ".pdiparams"))) - if args.device == "gpu": - am_encoder_infer_config.enable_use_gpu(100, 0) - am_decoder_config.enable_use_gpu(100, 0) - am_postnet_config.enable_use_gpu(100, 0) - elif args.device == "cpu": - am_encoder_infer_config.disable_gpu() - am_decoder_config.disable_gpu() - am_postnet_config.disable_gpu() - - am_encoder_infer_config.enable_memory_optim() - am_decoder_config.enable_memory_optim() - am_postnet_config.enable_memory_optim() - - am_encoder_infer_predictor = inference.create_predictor( - am_encoder_infer_config) - am_decoder_predictor = inference.create_predictor(am_decoder_config) - am_postnet_predictor = inference.create_predictor(am_postnet_config) - return am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor - - def get_am_sublayer_output(am_sublayer_predictor, input): am_sublayer_input_names = am_sublayer_predictor.get_input_names() input_handle = am_sublayer_predictor.get_input_handle( @@ -397,11 +391,15 @@ def get_am_sublayer_output(am_sublayer_predictor, input): return am_sublayer_output -def get_streaming_am_output(args, am_encoder_infer_predictor, - am_decoder_predictor, am_postnet_predictor, - frontend, merge_sentences, input): +def get_streaming_am_output(input: str, + am_encoder_infer_predictor, + am_decoder_predictor, + am_postnet_predictor, + frontend, + lang: str='zh', + merge_sentences: bool=True): get_tone_ids = False - if args.lang == 'zh': + if lang == 'zh': input_ids = frontend.get_input_ids( input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] @@ -423,58 +421,27 @@ def get_streaming_am_output(args, am_encoder_infer_predictor, return normalized_mel -def get_sess(args, filed='am'): - full_name = '' - if filed == 'am': - full_name = args.am - elif filed == 'voc': - full_name = args.voc - model_dir = str(Path(args.inference_dir) / (full_name + ".onnx")) +# onnx +def get_sess(model_dir: Optional[os.PathLike]=None, + model_file: Optional[os.PathLike]=None, + device: str='cpu', + cpu_threads: int=1, + use_trt: bool=False): + + model_dir = str(Path(model_dir) / model_file) sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - if args.device == "gpu": + if device == "gpu": # fastspeech2/mb_melgan can't use trt now! 
- if args.use_trt: + if use_trt: providers = ['TensorrtExecutionProvider'] else: providers = ['CUDAExecutionProvider'] - elif args.device == "cpu": + elif device == "cpu": providers = ['CPUExecutionProvider'] - sess_options.intra_op_num_threads = args.cpu_threads + sess_options.intra_op_num_threads = cpu_threads sess = ort.InferenceSession( model_dir, providers=providers, sess_options=sess_options) return sess - - -# streaming am -def get_streaming_am_sess(args): - full_name = args.am - am_encoder_infer_model_dir = str( - Path(args.inference_dir) / (full_name + "_am_encoder_infer" + ".onnx")) - am_decoder_model_dir = str( - Path(args.inference_dir) / (full_name + "_am_decoder" + ".onnx")) - am_postnet_model_dir = str( - Path(args.inference_dir) / (full_name + "_am_postnet" + ".onnx")) - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - if args.device == "gpu": - # fastspeech2/mb_melgan can't use trt now! - if args.use_trt: - providers = ['TensorrtExecutionProvider'] - else: - providers = ['CUDAExecutionProvider'] - elif args.device == "cpu": - providers = ['CPUExecutionProvider'] - sess_options.intra_op_num_threads = args.cpu_threads - am_encoder_infer_sess = ort.InferenceSession( - am_encoder_infer_model_dir, - providers=providers, - sess_options=sess_options) - am_decoder_sess = ort.InferenceSession( - am_decoder_model_dir, providers=providers, sess_options=sess_options) - am_postnet_sess = ort.InferenceSession( - am_postnet_model_dir, providers=providers, sess_options=sess_options) - return am_encoder_infer_sess, am_decoder_sess, am_postnet_sess diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py index abb1eb4e..dd66e54e 100644 --- a/paddlespeech/t2s/exps/synthesize.py +++ b/paddlespeech/t2s/exps/synthesize.py @@ -50,11 +50,29 @@ def evaluate(args): print(voc_config) # acoustic model - am_inference, am_name, am_dataset = get_am_inference(args, am_config) - test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset) + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + + am_inference = get_am_inference( + am=args.am, + am_config=am_config, + am_ckpt=args.am_ckpt, + am_stat=args.am_stat, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict, + speaker_dict=args.speaker_dict) + test_dataset = get_test_dataset( + test_metadata=test_metadata, + am=args.am, + speaker_dict=args.speaker_dict, + voice_cloning=args.voice_cloning) # vocoder - voc_inference = get_voc_inference(args, voc_config) + voc_inference = get_voc_inference( + voc=args.voc, + voc_config=voc_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 6c28dc48..2f14ef56 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -42,24 +42,48 @@ def evaluate(args): print(am_config) print(voc_config) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) # acoustic model - am_inference, am_name, am_dataset = get_am_inference(args, am_config) + am_name = args.am[:args.am.rindex('_')] + 
am_dataset = args.am[args.am.rindex('_') + 1:] + + am_inference = get_am_inference( + am=args.am, + am_config=am_config, + am_ckpt=args.am_ckpt, + am_stat=args.am_stat, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict, + speaker_dict=args.speaker_dict) # vocoder - voc_inference = get_voc_inference(args, voc_config) + voc_inference = get_voc_inference( + voc=args.voc, + voc_config=voc_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) # whether dygraph to static if args.inference_dir: # acoustic model - am_inference = am_to_static(args, am_inference, am_name, am_dataset) + am_inference = am_to_static( + am_inference=am_inference, + am=args.am, + inference_dir=args.inference_dir, + speaker_dict=args.speaker_dict) # vocoder - voc_inference = voc_to_static(args, voc_inference) + voc_inference = voc_to_static( + voc_inference=voc_inference, + voc=args.voc, + inference_dir=args.inference_dir) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/paddlespeech/t2s/exps/synthesize_streaming.py b/paddlespeech/t2s/exps/synthesize_streaming.py index 4f7a84e9..3659cb49 100644 --- a/paddlespeech/t2s/exps/synthesize_streaming.py +++ b/paddlespeech/t2s/exps/synthesize_streaming.py @@ -49,10 +49,13 @@ def evaluate(args): print(am_config) print(voc_config) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) with open(args.phones_dict, "r") as f: phn_id = [line.strip().split() for line in f.readlines()] @@ -60,7 +63,6 @@ def evaluate(args): print("vocab_size:", vocab_size) # acoustic model, only support fastspeech2 here now! 
- # am_inference, am_name, am_dataset = get_am_inference(args, am_config) # model: {model_name}_{dataset} am_name = args.am[:args.am.rindex('_')] am_dataset = args.am[args.am.rindex('_') + 1:] @@ -80,7 +82,11 @@ def evaluate(args): am_postnet = am.postnet # vocoder - voc_inference = get_voc_inference(args, voc_config) + voc_inference = get_voc_inference( + voc=args.voc, + voc_config=voc_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) # whether dygraph to static if args.inference_dir: @@ -115,7 +121,10 @@ def evaluate(args): os.path.join(args.inference_dir, args.am + "_am_postnet")) # vocoder - voc_inference = voc_to_static(args, voc_inference) + voc_inference = voc_to_static( + voc_inference=voc_inference, + voc=args.voc, + inference_dir=args.inference_dir) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/paddlespeech/t2s/exps/voice_cloning.py b/paddlespeech/t2s/exps/voice_cloning.py index 1afd21df..9257b07d 100644 --- a/paddlespeech/t2s/exps/voice_cloning.py +++ b/paddlespeech/t2s/exps/voice_cloning.py @@ -66,10 +66,19 @@ def voice_cloning(args): print("frontend done!") # acoustic model - am_inference, *_ = get_am_inference(args, am_config) + am_inference = get_am_inference( + am=args.am, + am_config=am_config, + am_ckpt=args.am_ckpt, + am_stat=args.am_stat, + phones_dict=args.phones_dict) # vocoder - voc_inference = get_voc_inference(args, voc_config) + voc_inference = get_voc_inference( + voc=args.voc, + voc_config=voc_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/paddlespeech/t2s/exps/wavernn/synthesize.py b/paddlespeech/t2s/exps/wavernn/synthesize.py index d23e9cb7..ea48a617 100644 --- a/paddlespeech/t2s/exps/wavernn/synthesize.py +++ b/paddlespeech/t2s/exps/wavernn/synthesize.py @@ -58,8 +58,7 @@ def main(): else: print("ngpu should >= 0 !") - model = WaveRNN( - hop_length=config.n_shift, sample_rate=config.fs, **config["model"]) + model = WaveRNN(**config["model"]) state_dict = paddle.load(args.checkpoint) model.set_state_dict(state_dict["main_params"]) From 4646f7cc8de954497e9edc8ff10ca95b171d8fdb Mon Sep 17 00:00:00 2001 From: TianYuan Date: Wed, 20 Apr 2022 09:48:46 +0000 Subject: [PATCH 31/31] add paddle device set for ort and inference, test=doc --- paddlespeech/t2s/exps/inference.py | 4 ++++ paddlespeech/t2s/exps/inference_streaming.py | 4 ++++ paddlespeech/t2s/exps/ort_predict.py | 4 ++++ paddlespeech/t2s/exps/ort_predict_e2e.py | 3 +++ paddlespeech/t2s/exps/ort_predict_streaming.py | 3 +++ 5 files changed, 18 insertions(+) diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 7a19a113..98e73e10 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -14,6 +14,7 @@ import argparse from pathlib import Path +import paddle import soundfile as sf from timer import timer @@ -101,6 +102,9 @@ def parse_args(): # only inference for models trained with csmsc now def main(): args = parse_args() + + paddle.set_device(args.device) + # frontend frontend = get_frontend( lang=args.lang, diff --git a/paddlespeech/t2s/exps/inference_streaming.py b/paddlespeech/t2s/exps/inference_streaming.py index ef6d1a4a..b680f19a 100644 --- a/paddlespeech/t2s/exps/inference_streaming.py +++ b/paddlespeech/t2s/exps/inference_streaming.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import 
timer @@ -100,6 +101,9 @@ def parse_args(): # only inference for models trained with csmsc now def main(): args = parse_args() + + paddle.set_device(args.device) + # frontend frontend = get_frontend( lang=args.lang, diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py index adbd6809..2e8596de 100644 --- a/paddlespeech/t2s/exps/ort_predict.py +++ b/paddlespeech/t2s/exps/ort_predict.py @@ -16,6 +16,7 @@ from pathlib import Path import jsonlines import numpy as np +import paddle import soundfile as sf from timer import timer @@ -25,6 +26,7 @@ from paddlespeech.t2s.utils import str2bool def ort_predict(args): + # construct dataset for evaluation with jsonlines.open(args.test_metadata, 'r') as reader: test_metadata = list(reader) @@ -143,6 +145,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args) diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py index ae5e900b..a2ef8e4c 100644 --- a/paddlespeech/t2s/exps/ort_predict_e2e.py +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -178,6 +179,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args) diff --git a/paddlespeech/t2s/exps/ort_predict_streaming.py b/paddlespeech/t2s/exps/ort_predict_streaming.py index 5568ed39..5d2c66bc 100644 --- a/paddlespeech/t2s/exps/ort_predict_streaming.py +++ b/paddlespeech/t2s/exps/ort_predict_streaming.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -246,6 +247,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args)
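
The patches above leave paddlespeech/t2s/exps/syn_utils.py with keyword-argument helpers (get_frontend, get_am_inference, get_voc_inference, get_sentences) instead of the old args-based ones, and PATCH 31 adds an explicit paddle.set_device(...) call before they are used. The sketch below is a minimal, illustrative way to wire those helpers together for CSMSC synthesis; it is not part of the patch series, and the config/checkpoint/dictionary paths (fastspeech2_csmsc/..., hifigan_csmsc/..., sentences.txt) are placeholder names, not files shipped by these patches.

# Minimal usage sketch for the refactored syn_utils helpers (placeholder paths).
import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode

from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_voc_inference

paddle.set_device("cpu")  # PATCH 31: choose the device before building models

# Placeholder model directories; point these at real downloaded checkpoints.
with open("fastspeech2_csmsc/default.yaml") as f:
    am_config = CfgNode(yaml.safe_load(f))
with open("hifigan_csmsc/default.yaml") as f:
    voc_config = CfgNode(yaml.safe_load(f))

frontend = get_frontend(
    lang="zh", phones_dict="fastspeech2_csmsc/phone_id_map.txt")
am_inference = get_am_inference(
    am="fastspeech2_csmsc",
    am_config=am_config,
    am_ckpt="fastspeech2_csmsc/snapshot_iter_76000.pdz",
    am_stat="fastspeech2_csmsc/speech_stats.npy",
    phones_dict="fastspeech2_csmsc/phone_id_map.txt")
voc_inference = get_voc_inference(
    voc="hifigan_csmsc",
    voc_config=voc_config,
    voc_ckpt="hifigan_csmsc/snapshot_iter_2500000.pdz",
    voc_stat="hifigan_csmsc/feats_stats.npy")

for utt_id, sentence in get_sentences(text_file="sentences.txt", lang="zh"):
    # The frontend returns a list of phone-id tensors, one per (merged) sentence.
    phone_ids = frontend.get_input_ids(
        sentence, merge_sentences=True)["phone_ids"][0]
    with paddle.no_grad():  # same intent as the @paddle.no_grad() decorators in PATCH 28
        mel = am_inference(phone_ids)
        wav = voc_inference(mel)
    sf.write(utt_id + ".wav", wav.numpy(), samplerate=am_config.fs)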