Merge branch 'develop' into ngram

pull/1729/head
Hui Zhang 2 years ago committed by GitHub
commit c938a450b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -27,20 +27,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--phones_dict=dump/phone_id_map.txt
fi
# style melgan
# style melgan's Dygraph to Static Graph is not ready now
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \
--inference_dir=${train_output_path}/inference \
--am=tacotron2_csmsc \
--voc=style_melgan_csmsc \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/pd_infer_out \
--phones_dict=dump/phone_id_map.txt
fi
# hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \
--inference_dir=${train_output_path}/inference \
--am=tacotron2_csmsc \

@ -28,7 +28,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--phones_dict=dump/phone_id_map.txt
fi
# hifigan
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \

@ -109,6 +109,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--lang=zh \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \
--inference_dir=${train_output_path}/inference
--phones_dict=dump/phone_id_map.txt #\
# --inference_dir=${train_output_path}/inference
fi

@ -26,7 +26,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi
# hifigan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \

@ -40,7 +40,6 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor']
@cli_register(
name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor):
@ -125,6 +124,7 @@ class ASRExecutor(BaseExecutor):
"""
Init model and other resources from a specific path.
"""
logger.info("start to init the model")
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
return
@ -140,14 +140,15 @@ class ASRExecutor(BaseExecutor):
res_path,
self.pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
@ -176,7 +177,6 @@ class ASRExecutor(BaseExecutor):
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.decoding_method = decode_method
else:
raise Exception("wrong type")
model_name = model_type[:model_type.rindex(
@ -254,12 +254,14 @@ class ASRExecutor(BaseExecutor):
else:
raise Exception("wrong type")
logger.info("audio feat process success")
@paddle.no_grad()
def infer(self, model_type: str):
"""
Model inference and result stored in self.output.
"""
logger.info("start to infer the model to get the output")
cfg = self.config.decode
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
@ -276,17 +278,22 @@ class ASRExecutor(BaseExecutor):
self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type:
result_transcripts = self.model.decode(
audio,
audio_len,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
self._outputs["result"] = result_transcripts[0][0]
logger.info(f"we will use the transformer like model : {model_type}")
try:
result_transcripts = self.model.decode(
audio,
audio_len,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size,
ctc_weight=cfg.ctc_weight,
decoding_chunk_size=cfg.decoding_chunk_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
self._outputs["result"] = result_transcripts[0][0]
except Exception as e:
logger.exception(e)
else:
raise Exception("invalid model name")

@ -88,6 +88,8 @@ model_alias = {
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer":
"paddlespeech.s2t.models.u2:U2Model",
"conformer_online":
"paddlespeech.s2t.models.u2:U2Model",
"transformer":
"paddlespeech.s2t.models.u2:U2Model",
"wenetspeech":

@ -279,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
# logp: (B*N, vocab)
logp, cache = self.decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache)
# 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
@ -708,11 +707,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
batch_size = feats.shape[0]
if decoding_method in ['ctc_prefix_beam_search',
'attention_rescoring'] and batch_size > 1:
logger.fatal(
logger.error(
f'decoding mode {decoding_method} must be running with batch_size == 1'
)
logger.error(f"current batch_size is {batch_size}")
sys.exit(1)
if decoding_method == 'attention':
hyps = self.recognize(
feats,

@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase):
# init once
if self._ext_scorer is not None:
return
if language_model_path != '':
logger.info("begin to initialize the external scorer "
"for decoding")

@ -35,3 +35,16 @@
```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
```
## Online ASR Server
### Lanuch online asr server
```
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### Access online asr server
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
```

@ -35,3 +35,17 @@
```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
```
## 流式ASR
### 启动流式语音识别服务
```
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### 访问流式语音识别服务
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
```

@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor):
lang=lang,
audio_format=audio_format)
time_end = time.time()
logger.info(res.json())
logger.info(res)
logger.info("Response time %f s." % (time_end - time_start))
return True
except Exception as e:
logger.error("Failed to speech recognition.")
logger.error(e)
return False
@stats_wrapper
@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor):
logging.info("asr websocket client start")
handler = ASRAudioHandler(server_ip, port)
loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(input))
res = loop.run_until_complete(handler.run(input))
logging.info("asr websocket client finished")
return res['asr_results']
@cli_client_register(
name='paddlespeech_client.cls', description='visit cls service')

@ -41,11 +41,7 @@ asr_online:
shift_ms: 40
sample_rate: 16000
sample_width: 2
vad_conf:
aggressiveness: 2
sample_rate: 16000
frame_duration_ms: 20
sample_width: 2
padding_ms: 200
padding_ratio: 0.9
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 20 # ms
shift_ms: 10 # ms

@ -0,0 +1,45 @@
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'conformer_online_multicn'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 25 # ms
shift_ms: 10 # ms
sample_rate: 16000
sample_width: 2

File diff suppressed because it is too large Load Diff

@ -0,0 +1,128 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddlespeech.cli.log import logger
from paddlespeech.s2t.utils.utility import log_add
__all__ = ['CTCPrefixBeamSearch']
class CTCPrefixBeamSearch:
def __init__(self, config):
"""Implement the ctc prefix beam search
Args:
config (yacs.config.CfgNode): _description_
"""
self.config = config
self.reset()
@paddle.no_grad()
def search(self, ctc_probs, device, blank_id=0):
"""ctc prefix beam search method decode a chunk feature
Args:
xs (paddle.Tensor): feature data
ctc_probs (paddle.Tensor): the ctc probability of all the tokens
device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0).
blank_id (int, optional): the blank id in the vocab. Defaults to 0.
Returns:
list: the search result
"""
# decode
logger.info("start to ctc prefix search")
batch_size = 1
beam_size = self.config.beam_size
maxlen = ctc_probs.shape[0]
assert len(ctc_probs.shape) == 2
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
if self.cur_hyps is None:
self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in self.cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
self.cur_hyps = next_hyps[:beam_size]
self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
logger.info("ctc prefix search success")
return self.hyps
def get_one_best_hyps(self):
"""Return the one best result
Returns:
list: the one best result
"""
return [self.hyps[0][0]]
def get_hyps(self):
"""Return the search hyps
Returns:
list: return the search hyps
"""
return self.hyps
def reset(self):
"""Rest the search cache value
"""
self.cur_hyps = None
self.hyps = None
def finalize_search(self):
"""do nothing in ctc_prefix_beam_search
"""
pass

@ -12,24 +12,329 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import math
import os
import time
from typing import Optional
import numpy as np
import paddle
import yaml
from yacs.config import CfgNode
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
__all__ = ['TTSEngine']
# support online model
pretrained_models = {
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_cnndecoder_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip',
'md5':
'6eb28e22ace73e0ebe7845f86478f89f',
'config':
'cnndecoder.yaml',
'ckpt':
'snapshot_iter_153000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
}
model_alias = {
# acoustic model
"fastspeech2":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2",
"fastspeech2_inference":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
# voc
"mb_melgan":
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
}
__all__ = ['TTSEngine']
class TTSServerExecutor(TTSExecutor):
def __init__(self):
def __init__(self, am_block, am_pad, voc_block, voc_pad):
super().__init__()
pass
self.am_block = am_block
self.am_pad = am_pad
self.voc_block = voc_block
self.voc_pad = voc_pad
def get_model_info(self,
field: str,
model_name: str,
ckpt: Optional[os.PathLike],
stat: Optional[os.PathLike]):
"""get model information
Args:
field (str): am or voc
model_name (str): model type, support fastspeech2, higigan, mb_melgan
ckpt (Optional[os.PathLike]): ckpt file
stat (Optional[os.PathLike]): stat file, including mean and standard deviation
Returns:
[module]: model module
[Tensor]: mean
[Tensor]: standard deviation
"""
model_class = dynamic_import(model_name, model_alias)
if field == "am":
odim = self.am_config.n_mels
model = model_class(
idim=self.vocab_size, odim=odim, **self.am_config["model"])
model.set_state_dict(paddle.load(ckpt)["main_params"])
elif field == "voc":
model = model_class(**self.voc_config["generator_params"])
model.set_state_dict(paddle.load(ckpt)["generator_params"])
model.remove_weight_norm()
else:
logger.error("Please set correct field, am or voc")
model.eval()
model_mu, model_std = np.load(stat)
model_mu = paddle.to_tensor(model_mu)
model_std = paddle.to_tensor(model_std)
return model, model_mu, model_std
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(
self,
am: str='fastspeech2_csmsc',
am_config: Optional[os.PathLike]=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
voc: str='mb_melgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None,
lang: str='zh', ):
"""
Init model and other resources from a specific path.
"""
if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'):
logger.info('Models had been initialized.')
return
# am model info
am_tag = am + '-' + lang
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_config = os.path.join(am_res_path,
pretrained_models[am_tag]['config'])
self.am_ckpt = os.path.join(am_res_path,
pretrained_models[am_tag]['ckpt'])
self.am_stat = os.path.join(
am_res_path, pretrained_models[am_tag]['speech_stats'])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, pretrained_models[am_tag]['phones_dict'])
print("self.phones_dict:", self.phones_dict)
logger.info(am_res_path)
logger.info(self.am_config)
logger.info(self.am_ckpt)
else:
self.am_config = os.path.abspath(am_config)
self.am_ckpt = os.path.abspath(am_ckpt)
self.am_stat = os.path.abspath(am_stat)
self.phones_dict = os.path.abspath(phones_dict)
self.am_res_path = os.path.dirname(os.path.abspath(self.am_config))
print("self.phones_dict:", self.phones_dict)
self.tones_dict = None
self.speaker_dict = None
# voc model info
voc_tag = voc + '-' + lang
if voc_ckpt is None or voc_config is None or voc_stat is None:
voc_res_path = self._get_pretrained_path(voc_tag)
self.voc_res_path = voc_res_path
self.voc_config = os.path.join(voc_res_path,
pretrained_models[voc_tag]['config'])
self.voc_ckpt = os.path.join(voc_res_path,
pretrained_models[voc_tag]['ckpt'])
self.voc_stat = os.path.join(
voc_res_path, pretrained_models[voc_tag]['speech_stats'])
logger.info(voc_res_path)
logger.info(self.voc_config)
logger.info(self.voc_ckpt)
else:
self.voc_config = os.path.abspath(voc_config)
self.voc_ckpt = os.path.abspath(voc_ckpt)
self.voc_stat = os.path.abspath(voc_stat)
self.voc_res_path = os.path.dirname(
os.path.abspath(self.voc_config))
# Init body.
with open(self.am_config) as f:
self.am_config = CfgNode(yaml.safe_load(f))
with open(self.voc_config) as f:
self.voc_config = CfgNode(yaml.safe_load(f))
with open(self.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
self.vocab_size = len(phn_id)
print("vocab_size:", self.vocab_size)
# frontend
if lang == 'zh':
self.frontend = Frontend(
phone_vocab_path=self.phones_dict,
tone_vocab_path=self.tones_dict)
elif lang == 'en':
self.frontend = English(phone_vocab_path=self.phones_dict)
print("frontend done!")
# am infer info
self.am_name = am[:am.rindex('_')]
if self.am_name == "fastspeech2_cnndecoder":
self.am_inference, self.am_mu, self.am_std = self.get_model_info(
"am", "fastspeech2", self.am_ckpt, self.am_stat)
else:
am, am_mu, am_std = self.get_model_info("am", self.am_name,
self.am_ckpt, self.am_stat)
am_normalizer = ZScore(am_mu, am_std)
am_inference_class = dynamic_import(self.am_name + '_inference',
model_alias)
self.am_inference = am_inference_class(am_normalizer, am)
self.am_inference.eval()
print("acoustic model done!")
# voc infer info
self.voc_name = voc[:voc.rindex('_')]
voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name,
self.voc_ckpt, self.voc_stat)
voc_normalizer = ZScore(voc_mu, voc_std)
voc_inference_class = dynamic_import(self.voc_name + '_inference',
model_alias)
self.voc_inference = voc_inference_class(voc_normalizer, voc)
self.voc_inference.eval()
print("voc done!")
def get_phone(self, sentence, lang, merge_sentences, get_tone_ids):
tone_ids = None
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
"""
front_pad = min(chunk_id * block, pad)
# first chunk
if chunk_id == 0:
data = data[:block * upsample]
# last chunk
elif chunk_id == chunk_num - 1:
data = data[front_pad * upsample:]
# middle chunk
else:
data = data[front_pad * upsample:(front_pad + block) * upsample]
return data
@paddle.no_grad()
def infer(
@ -37,16 +342,20 @@ class TTSServerExecutor(TTSExecutor):
text: str,
lang: str='zh',
am: str='fastspeech2_csmsc',
spk_id: int=0,
am_block: int=42,
am_pad: int=12,
voc_block: int=14,
voc_pad: int=14, ):
spk_id: int=0, ):
"""
Model inference and result stored in self.output.
"""
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_block = self.am_block
am_pad = self.am_pad
am_upsample = 1
voc_block = self.voc_block
voc_pad = self.voc_pad
voc_upsample = self.voc_config.n_shift
# first_flag 用于标记首包
first_flag = 1
get_tone_ids = False
merge_sentences = False
frontend_st = time.time()
@ -64,43 +373,100 @@ class TTSServerExecutor(TTSExecutor):
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
self.frontend_time = time.time() - frontend_st
frontend_et = time.time()
self.frontend_time = frontend_et - frontend_st
for i in range(len(phone_ids)):
am_st = time.time()
part_phone_ids = phone_ids[i]
# am
if am_name == 'speedyspeech':
part_tone_ids = tone_ids[i]
mel = self.am_inference(part_phone_ids, part_tone_ids)
# fastspeech2
voc_chunk_id = 0
# fastspeech2_csmsc
if am == "fastspeech2_csmsc":
# am
mel = self.am_inference(part_phone_ids)
if first_flag == 1:
first_am_et = time.time()
self.first_am_infer = first_am_et - frontend_et
# voc streaming
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
voc_chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_inference(mel_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
voc_block, voc_pad, voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
self.first_response_time = first_voc_et - frontend_st
first_flag = 0
yield sub_wav
# fastspeech2_cnndecoder_csmsc
elif am == "fastspeech2_cnndecoder_csmsc":
# am
orig_hs, h_masks = self.am_inference.encoder_infer(
part_phone_ids)
# streaming voc chunk info
mel_len = orig_hs.shape[1]
voc_chunk_num = math.ceil(mel_len / self.voc_block)
start = 0
end = min(self.voc_block + self.voc_pad, mel_len)
# streaming am
hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
before_outs, _ = self.am_inference.decoder(hs)
after_outs = before_outs + self.am_inference.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
normalized_mel = after_outs[0]
sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block,
am_pad, am_upsample)
if i == 0:
mel_streaming = sub_mel
else:
mel_streaming = np.concatenate(
(mel_streaming, sub_mel), axis=0)
# streaming voc
# 当流式AM推理的mel帧数大于流式voc推理的chunk size开始进行流式voc 推理
while (mel_streaming.shape[0] >= end and
voc_chunk_id < voc_chunk_num):
if first_flag == 1:
first_am_et = time.time()
self.first_am_infer = first_am_et - frontend_et
voc_chunk = mel_streaming[start:end, :]
voc_chunk = paddle.to_tensor(voc_chunk)
sub_wav = self.voc_inference(voc_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num,
voc_chunk_id, voc_block,
voc_pad, voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
self.first_response_time = first_voc_et - frontend_st
first_flag = 0
yield sub_wav
voc_chunk_id += 1
start = max(0, voc_chunk_id * voc_block - voc_pad)
end = min((voc_chunk_id + 1) * voc_block + voc_pad,
mel_len)
else:
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
mel = self.am_inference(
part_phone_ids, spk_id=paddle.to_tensor(spk_id))
else:
mel = self.am_inference(part_phone_ids)
am_et = time.time()
# voc streaming
voc_upsample = self.voc_config.n_shift
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_inference(mel_chunk)
front_pad = min(i * voc_block, voc_pad)
if i == 0:
sub_wav = sub_wav[:voc_block * voc_upsample]
elif i == chunk_num - 1:
sub_wav = sub_wav[front_pad * voc_upsample:]
else:
sub_wav = sub_wav[front_pad * voc_upsample:(
front_pad + voc_block) * voc_upsample]
yield sub_wav
logger.error(
"Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts."
)
self.final_response_time = time.time() - frontend_st
class TTSEngine(BaseEngine):
@ -113,14 +479,21 @@ class TTSEngine(BaseEngine):
def __init__(self, name=None):
"""Initialize TTS server engine
"""
super(TTSEngine, self).__init__()
super().__init__()
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config
assert "fastspeech2_csmsc" in config.am and (
config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc"
assert (
config.am == "fastspeech2_csmsc" or
config.am == "fastspeech2_cnndecoder_csmsc"
) and (
config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc"
), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
assert (
config.voc_block > 0 and config.voc_pad > 0
), "Please set correct voc_block and voc_pad, they should be more than 0."
try:
if self.config.device:
self.device = self.config.device
@ -135,6 +508,9 @@ class TTSEngine(BaseEngine):
(self.device))
return False
self.executor = TTSServerExecutor(config.am_block, config.am_pad,
config.voc_block, config.voc_pad)
try:
self.executor._init_from_path(
am=self.config.am,
@ -155,15 +531,42 @@ class TTSEngine(BaseEngine):
(self.device))
return False
self.am_block = self.config.am_block
self.am_pad = self.config.am_pad
self.voc_block = self.config.voc_block
self.voc_pad = self.config.voc_pad
logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device))
# warm up
try:
self.warm_up()
except Exception as e:
logger.error("Failed to warm up on tts engine.")
return False
return True
def warm_up(self):
"""warm up
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info(
"*******************************warm up ********************************"
)
for i in range(3):
for wav in self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, ):
logger.info(
f"The first response time of the {i} warm up: {self.executor.first_response_time} s"
)
break
logger.info(
"**********************************************************************"
)
def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
# Convert byte to text
if text_bese64:
@ -195,18 +598,14 @@ class TTSEngine(BaseEngine):
wav_base64: The base64 format of the synthesized audio.
"""
lang = self.config.lang
wav_list = []
for wav in self.executor.infer(
text=sentence,
lang=lang,
lang=self.config.lang,
am=self.config.am,
spk_id=spk_id,
am_block=self.am_block,
am_pad=self.am_pad,
voc_block=self.voc_block,
voc_pad=self.voc_pad):
spk_id=spk_id, ):
# wav type: <class 'numpy.ndarray'> float32, convert to pcm (base64)
wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes
@ -216,5 +615,14 @@ class TTSEngine(BaseEngine):
yield wav_base64
wav_all = np.concatenate(wav_list, axis=0)
logger.info("The durations of audio is: {} s".format(
len(wav_all) / self.executor.am_config.fs))
duration = len(wav_all) / self.executor.am_config.fs
logger.info(f"sentence: {sentence}")
logger.info(f"The durations of audio is: {duration} s")
logger.info(
f"first response time: {self.executor.first_response_time} s")
logger.info(
f"final response time: {self.executor.final_response_time} s")
logger.info(f"RTF: {self.executor.final_response_time / duration}")
logger.info(
f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
)

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -34,10 +34,9 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples)
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size != 0:
chunk_size = 85 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size!= 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
padding_len_x = 0
@ -48,7 +47,6 @@ class ASRAudioHandler:
assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk)
for i in range(0, num_chunk):
start = i * chunk_size
end = start + chunk_size
@ -57,7 +55,11 @@ class ASRAudioHandler:
async def run(self, wavfile_path: str):
logging.info("send a message to the server")
# self.read_wave()
# send websocket handshake protocal
async with websockets.connect(self.url) as ws:
# server has already received handshake protocal
# client start to send the command
audio_info = json.dumps(
{
"name": "test.wav",
@ -78,7 +80,6 @@ class ASRAudioHandler:
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
result = msg
# finished
audio_info = json.dumps(
{
@ -91,10 +92,12 @@ class ASRAudioHandler:
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
# decode the bytes to str
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
return result
logging.info("final receive msg={}".format(msg))
result = msg
return result
def main(args):

@ -63,12 +63,12 @@ class ChunkBuffer(object):
the sample rate.
Yields Frames of the requested duration.
"""
audio = self.remained_audio + audio
self.remained_audio = b''
offset = 0
timestamp = 0.0
while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec)

@ -52,6 +52,10 @@ def get_chunks(data, block_size, pad_size, step):
Returns:
list: chunks list
"""
if block_size == -1:
return [data]
if step == "am":
data_len = data.shape[1]
elif step == "voc":

@ -13,12 +13,12 @@
# limitations under the License.
import json
import numpy as np
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio
@ -28,26 +28,29 @@ router = APIRouter()
@router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
connection_handler = None
# init buffer
# each websocekt connection has its own chunk buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer(
window_n=7,
shift_n=4,
window_ms=20,
shift_ms=10,
sample_rate=chunk_buffer_conf['sample_rate'],
sample_width=chunk_buffer_conf['sample_width'])
window_n=chunk_buffer_conf.window_n,
shift_n=chunk_buffer_conf.shift_n,
window_ms=chunk_buffer_conf.window_ms,
shift_ms=chunk_buffer_conf.shift_ms,
sample_rate=chunk_buffer_conf.sample_rate,
sample_width=chunk_buffer_conf.sample_width)
# init vad
vad_conf = asr_engine.config.vad_conf
vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
vad_conf = asr_engine.config.get('vad_conf', None)
if vad_conf:
vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
try:
while True:
@ -64,13 +67,21 @@ async def websocket_endpoint(websocket: WebSocket):
if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"}
# do something at begining here
# create the instance to process the audio
connection_handler = PaddleASRConnectionHanddler(asr_engine)
await websocket.send_json(resp)
elif message['signal'] == 'end':
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# reset single engine for an new connection
asr_engine.reset()
resp = {"status": "ok", "signal": "finished"}
connection_handler.decode(is_finished=True)
connection_handler.rescoring()
asr_results = connection_handler.get_result()
connection_handler.reset()
resp = {
"status": "ok",
"signal": "finished",
'asr_results': asr_results
}
await websocket.send_json(resp)
break
else:
@ -79,21 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message:
message = message["bytes"]
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_results = ""
frames = chunk_buffer.frame_generator(message)
for frame in frames:
samples = np.frombuffer(frame.bytes, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate)
asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess()
connection_handler.extract_feat(message)
connection_handler.decode(is_finished=False)
asr_results = connection_handler.get_result()
asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results}
await websocket.send_json(resp)
except WebSocketDisconnect:
pass

@ -14,6 +14,7 @@
import argparse
from pathlib import Path
import paddle
import soundfile as sf
from timer import timer
@ -101,21 +102,35 @@ def parse_args():
# only inference for models trained with csmsc now
def main():
args = parse_args()
paddle.set_device(args.device)
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# am_predictor
am_predictor = get_predictor(args, filed='am')
am_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + ".pdmodel",
params_file=args.am + ".pdiparams",
device=args.device)
# model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:]
# voc_predictor
voc_predictor = get_predictor(args, filed='voc')
voc_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel",
params_file=args.voc + ".pdiparams",
device=args.device)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
merge_sentences = True
fs = 24000 if am_dataset != 'ljspeech' else 22050
@ -123,11 +138,13 @@ def main():
for utt_id, sentence in sentences[:3]:
with timer() as t:
am_output_data = get_am_output(
args,
input=sentence,
am_predictor=am_predictor,
am=args.am,
frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences,
input=sentence)
speaker_dict=args.speaker_dict, )
wav = get_voc_output(
voc_predictor=voc_predictor, input=am_output_data)
speed = wav.size / t.elapse
@ -143,11 +160,13 @@ def main():
for utt_id, sentence in sentences:
with timer() as t:
am_output_data = get_am_output(
args,
input=sentence,
am_predictor=am_predictor,
am=args.am,
frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences,
input=sentence)
speaker_dict=args.speaker_dict, )
wav = get_voc_output(
voc_predictor=voc_predictor, input=am_output_data)

@ -15,6 +15,7 @@ import argparse
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
from timer import timer
@ -25,7 +26,6 @@ from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_predictor
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_predictor
from paddlespeech.t2s.exps.syn_utils import get_voc_output
from paddlespeech.t2s.utils import str2bool
@ -101,23 +101,47 @@ def parse_args():
# only inference for models trained with csmsc now
def main():
args = parse_args()
paddle.set_device(args.device)
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# am_predictor
am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor = get_streaming_am_predictor(
args)
am_encoder_infer_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_encoder_infer" + ".pdmodel",
params_file=args.am + "_am_encoder_infer" + ".pdiparams",
device=args.device)
am_decoder_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_decoder" + ".pdmodel",
params_file=args.am + "_am_decoder" + ".pdiparams",
device=args.device)
am_postnet_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_postnet" + ".pdmodel",
params_file=args.am + "_am_postnet" + ".pdiparams",
device=args.device)
am_mu, am_std = np.load(args.am_stat)
# model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:]
# voc_predictor
voc_predictor = get_predictor(args, filed='voc')
voc_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel",
params_file=args.voc + ".pdiparams",
device=args.device)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
merge_sentences = True
@ -126,13 +150,13 @@ def main():
for utt_id, sentence in sentences[:3]:
with timer() as t:
normalized_mel = get_streaming_am_output(
args,
input=sentence,
am_encoder_infer_predictor=am_encoder_infer_predictor,
am_decoder_predictor=am_decoder_predictor,
am_postnet_predictor=am_postnet_predictor,
frontend=frontend,
merge_sentences=merge_sentences,
input=sentence)
lang=args.lang,
merge_sentences=merge_sentences, )
mel = denorm(normalized_mel, am_mu, am_std)
wav = get_voc_output(voc_predictor=voc_predictor, input=mel)
speed = wav.size / t.elapse

@ -16,6 +16,7 @@ from pathlib import Path
import jsonlines
import numpy as np
import paddle
import soundfile as sf
from timer import timer
@ -25,12 +26,13 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args):
# construct dataset for evaluation
with jsonlines.open(args.test_metadata, 'r') as reader:
test_metadata = list(reader)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)
test_dataset = get_test_dataset(test_metadata=test_metadata, am=args.am)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
@ -38,10 +40,18 @@ def ort_predict(args):
fs = 24000 if am_dataset != 'ljspeech' else 22050
# am
am_sess = get_sess(args, filed='am')
am_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# vocoder
voc_sess = get_sess(args, filed='voc')
voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# am warmup
for T in [27, 38, 54]:
@ -135,6 +145,8 @@ def parse_args():
def main():
args = parse_args()
paddle.set_device(args.device)
ort_predict(args)

@ -15,6 +15,7 @@ import argparse
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
from timer import timer
@ -27,21 +28,31 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args):
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
fs = 24000 if am_dataset != 'ljspeech' else 22050
# am
am_sess = get_sess(args, filed='am')
am_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# vocoder
voc_sess = get_sess(args, filed='voc')
voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# frontend warmup
# Loading model cost 0.5+ seconds
@ -168,6 +179,8 @@ def parse_args():
def main():
args = parse_args()
paddle.set_device(args.device)
ort_predict(args)

@ -15,6 +15,7 @@ import argparse
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
from timer import timer
@ -23,30 +24,50 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_sess
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_sess
from paddlespeech.t2s.utils import str2bool
def ort_predict(args):
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
fs = 24000 if am_dataset != 'ljspeech' else 22050
# am
am_encoder_infer_sess, am_decoder_sess, am_postnet_sess = get_streaming_am_sess(
args)
# streaming acoustic model
am_encoder_infer_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_encoder_infer" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_decoder_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_decoder" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_postnet_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_postnet" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_mu, am_std = np.load(args.am_stat)
# vocoder
voc_sess = get_sess(args, filed='voc')
voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# frontend warmup
# Loading model cost 0.5+ seconds
@ -226,6 +247,8 @@ def parse_args():
def main():
args = parse_args()
paddle.set_device(args.device)
ort_predict(args)

@ -14,6 +14,10 @@
import math
import os
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
import numpy as np
import onnxruntime as ort
@ -21,6 +25,7 @@ import paddle
from paddle import inference
from paddle import jit
from paddle.static import InputSpec
from yacs.config import CfgNode
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.datasets.data_table import DataTable
@ -70,7 +75,7 @@ def denorm(data, mean, std):
return data * std + mean
def get_chunks(data, chunk_size, pad_size):
def get_chunks(data, chunk_size: int, pad_size: int):
data_len = data.shape[1]
chunks = []
n = math.ceil(data_len / chunk_size)
@ -82,28 +87,34 @@ def get_chunks(data, chunk_size, pad_size):
# input
def get_sentences(args):
def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
# construct dataset for evaluation
sentences = []
with open(args.text, 'rt') as f:
with open(text_file, 'rt') as f:
for line in f:
items = line.strip().split()
utt_id = items[0]
if 'lang' in args and args.lang == 'zh':
if lang == 'zh':
sentence = "".join(items[1:])
elif 'lang' in args and args.lang == 'en':
elif lang == 'en':
sentence = " ".join(items[1:])
sentences.append((utt_id, sentence))
return sentences
def get_test_dataset(args, test_metadata, am_name, am_dataset):
def get_test_dataset(test_metadata: List[Dict[str, Any]],
am: str,
speaker_dict: Optional[os.PathLike]=None,
voice_cloning: bool=False):
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2':
fields = ["utt_id", "text"]
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
print("multiple speaker fastspeech2!")
fields += ["spk_id"]
elif 'voice_cloning' in args and args.voice_cloning:
elif voice_cloning:
print("voice cloning!")
fields += ["spk_emb"]
else:
@ -112,7 +123,7 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset):
fields = ["utt_id", "phones", "tones"]
elif am_name == 'tacotron2':
fields = ["utt_id", "text"]
if 'voice_cloning' in args and args.voice_cloning:
if voice_cloning:
print("voice cloning!")
fields += ["spk_emb"]
@ -121,12 +132,14 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset):
# frontend
def get_frontend(args):
if 'lang' in args and args.lang == 'zh':
def get_frontend(lang: str='zh',
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None):
if lang == 'zh':
frontend = Frontend(
phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict)
elif 'lang' in args and args.lang == 'en':
frontend = English(phone_vocab_path=args.phones_dict)
phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
elif lang == 'en':
frontend = English(phone_vocab_path=phones_dict)
else:
print("wrong lang!")
print("frontend done!")
@ -134,30 +147,37 @@ def get_frontend(args):
# dygraph
def get_am_inference(args, am_config):
with open(args.phones_dict, "r") as f:
def get_am_inference(
am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None, ):
with open(phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
tone_size = None
if 'tones_dict' in args and args.tones_dict:
with open(args.tones_dict, "r") as f:
if tones_dict is not None:
with open(tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id)
print("tone_size:", tone_size)
spk_num = None
if 'speaker_dict' in args and args.speaker_dict:
with open(args.speaker_dict, 'rt') as f:
if speaker_dict is not None:
with open(speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
print("spk_num:", spk_num)
odim = am_config.n_mels
# model: {model_name}_{dataset}
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_class = dynamic_import(am_name, model_alias)
am_inference_class = dynamic_import(am_name + '_inference', model_alias)
@ -174,34 +194,38 @@ def get_am_inference(args, am_config):
elif am_name == 'tacotron2':
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
am.set_state_dict(paddle.load(args.am_ckpt)["main_params"])
am.set_state_dict(paddle.load(am_ckpt)["main_params"])
am.eval()
am_mu, am_std = np.load(args.am_stat)
am_mu, am_std = np.load(am_stat)
am_mu = paddle.to_tensor(am_mu)
am_std = paddle.to_tensor(am_std)
am_normalizer = ZScore(am_mu, am_std)
am_inference = am_inference_class(am_normalizer, am)
am_inference.eval()
print("acoustic model done!")
return am_inference, am_name, am_dataset
return am_inference
def get_voc_inference(args, voc_config):
def get_voc_inference(
voc: str='pwgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None, ):
# model: {model_name}_{dataset}
voc_name = args.voc[:args.voc.rindex('_')]
voc_name = voc[:voc.rindex('_')]
voc_class = dynamic_import(voc_name, model_alias)
voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
if voc_name != 'wavernn':
voc = voc_class(**voc_config["generator_params"])
voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"])
voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"])
voc.remove_weight_norm()
voc.eval()
else:
voc = voc_class(**voc_config["model"])
voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"])
voc.set_state_dict(paddle.load(voc_ckpt)["main_params"])
voc.eval()
voc_mu, voc_std = np.load(args.voc_stat)
voc_mu, voc_std = np.load(voc_stat)
voc_mu = paddle.to_tensor(voc_mu)
voc_std = paddle.to_tensor(voc_std)
voc_normalizer = ZScore(voc_mu, voc_std)
@ -211,10 +235,16 @@ def get_voc_inference(args, voc_config):
return voc_inference
# to static
def am_to_static(args, am_inference, am_name, am_dataset):
# dygraph to static graph
def am_to_static(am_inference,
am: str='fastspeech2_csmsc',
inference_dir=Optional[os.PathLike],
speaker_dict: Optional[os.PathLike]=None):
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2':
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
@ -226,7 +256,7 @@ def am_to_static(args, am_inference, am_name, am_dataset):
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
elif am_name == 'speedyspeech':
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static(
am_inference,
input_spec=[
@ -247,56 +277,64 @@ def am_to_static(args, am_inference, am_name, am_dataset):
am_inference = jit.to_static(
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am))
am_inference = paddle.jit.load(os.path.join(args.inference_dir, args.am))
paddle.jit.save(am_inference, os.path.join(inference_dir, am))
am_inference = paddle.jit.load(os.path.join(inference_dir, am))
return am_inference
def voc_to_static(args, voc_inference):
def voc_to_static(voc_inference,
voc: str='pwgan_csmsc',
inference_dir=Optional[os.PathLike]):
voc_inference = jit.to_static(
voc_inference, input_spec=[
InputSpec([-1, 80], dtype=paddle.float32),
])
paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc))
voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc))
paddle.jit.save(voc_inference, os.path.join(inference_dir, voc))
voc_inference = paddle.jit.load(os.path.join(inference_dir, voc))
return voc_inference
# inference
def get_predictor(args, filed='am'):
full_name = ''
if filed == 'am':
full_name = args.am
elif filed == 'voc':
full_name = args.voc
def get_predictor(model_dir: Optional[os.PathLike]=None,
model_file: Optional[os.PathLike]=None,
params_file: Optional[os.PathLike]=None,
device: str='cpu'):
config = inference.Config(
str(Path(args.inference_dir) / (full_name + ".pdmodel")),
str(Path(args.inference_dir) / (full_name + ".pdiparams")))
if args.device == "gpu":
str(Path(model_dir) / model_file), str(Path(model_dir) / params_file))
if device == "gpu":
config.enable_use_gpu(100, 0)
elif args.device == "cpu":
elif device == "cpu":
config.disable_gpu()
config.enable_memory_optim()
predictor = inference.create_predictor(config)
return predictor
def get_am_output(args, am_predictor, frontend, merge_sentences, input):
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
def get_am_output(
input: str,
am_predictor,
am,
frontend,
lang: str='zh',
merge_sentences: bool=True,
speaker_dict: Optional[os.PathLike]=None,
spk_id: int=0, ):
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_input_names = am_predictor.get_input_names()
get_tone_ids = False
get_spk_id = False
if am_name == 'speedyspeech':
get_tone_ids = True
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict:
if am_dataset in {"aishell3", "vctk"} and speaker_dict:
get_spk_id = True
spk_id = np.array([args.spk_id])
if args.lang == 'zh':
spk_id = np.array([spk_id])
if lang == 'zh':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
elif args.lang == 'en':
elif lang == 'en':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
@ -338,50 +376,6 @@ def get_voc_output(voc_predictor, input):
return wav
# streaming am
def get_streaming_am_predictor(args):
full_name = args.am
am_encoder_infer_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_encoder_infer" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_encoder_infer" + ".pdiparams")))
am_decoder_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_decoder" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_decoder" + ".pdiparams")))
am_postnet_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_postnet" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_postnet" + ".pdiparams")))
if args.device == "gpu":
am_encoder_infer_config.enable_use_gpu(100, 0)
am_decoder_config.enable_use_gpu(100, 0)
am_postnet_config.enable_use_gpu(100, 0)
elif args.device == "cpu":
am_encoder_infer_config.disable_gpu()
am_decoder_config.disable_gpu()
am_postnet_config.disable_gpu()
am_encoder_infer_config.enable_memory_optim()
am_decoder_config.enable_memory_optim()
am_postnet_config.enable_memory_optim()
am_encoder_infer_predictor = inference.create_predictor(
am_encoder_infer_config)
am_decoder_predictor = inference.create_predictor(am_decoder_config)
am_postnet_predictor = inference.create_predictor(am_postnet_config)
return am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor
def get_am_sublayer_output(am_sublayer_predictor, input):
am_sublayer_input_names = am_sublayer_predictor.get_input_names()
input_handle = am_sublayer_predictor.get_input_handle(
@ -397,11 +391,15 @@ def get_am_sublayer_output(am_sublayer_predictor, input):
return am_sublayer_output
def get_streaming_am_output(args, am_encoder_infer_predictor,
am_decoder_predictor, am_postnet_predictor,
frontend, merge_sentences, input):
def get_streaming_am_output(input: str,
am_encoder_infer_predictor,
am_decoder_predictor,
am_postnet_predictor,
frontend,
lang: str='zh',
merge_sentences: bool=True):
get_tone_ids = False
if args.lang == 'zh':
if lang == 'zh':
input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
@ -423,58 +421,27 @@ def get_streaming_am_output(args, am_encoder_infer_predictor,
return normalized_mel
def get_sess(args, filed='am'):
full_name = ''
if filed == 'am':
full_name = args.am
elif filed == 'voc':
full_name = args.voc
model_dir = str(Path(args.inference_dir) / (full_name + ".onnx"))
# onnx
def get_sess(model_dir: Optional[os.PathLike]=None,
model_file: Optional[os.PathLike]=None,
device: str='cpu',
cpu_threads: int=1,
use_trt: bool=False):
model_dir = str(Path(model_dir) / model_file)
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if args.device == "gpu":
if device == "gpu":
# fastspeech2/mb_melgan can't use trt now!
if args.use_trt:
if use_trt:
providers = ['TensorrtExecutionProvider']
else:
providers = ['CUDAExecutionProvider']
elif args.device == "cpu":
elif device == "cpu":
providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = args.cpu_threads
sess_options.intra_op_num_threads = cpu_threads
sess = ort.InferenceSession(
model_dir, providers=providers, sess_options=sess_options)
return sess
# streaming am
def get_streaming_am_sess(args):
full_name = args.am
am_encoder_infer_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_encoder_infer" + ".onnx"))
am_decoder_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_decoder" + ".onnx"))
am_postnet_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_postnet" + ".onnx"))
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if args.device == "gpu":
# fastspeech2/mb_melgan can't use trt now!
if args.use_trt:
providers = ['TensorrtExecutionProvider']
else:
providers = ['CUDAExecutionProvider']
elif args.device == "cpu":
providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = args.cpu_threads
am_encoder_infer_sess = ort.InferenceSession(
am_encoder_infer_model_dir,
providers=providers,
sess_options=sess_options)
am_decoder_sess = ort.InferenceSession(
am_decoder_model_dir, providers=providers, sess_options=sess_options)
am_postnet_sess = ort.InferenceSession(
am_postnet_model_dir, providers=providers, sess_options=sess_options)
return am_encoder_infer_sess, am_decoder_sess, am_postnet_sess

@ -50,11 +50,29 @@ def evaluate(args):
print(voc_config)
# acoustic model
am_inference, am_name, am_dataset = get_am_inference(args, am_config)
test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
test_dataset = get_test_dataset(
test_metadata=test_metadata,
am=args.am,
speaker_dict=args.speaker_dict,
voice_cloning=args.voice_cloning)
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

@ -42,24 +42,48 @@ def evaluate(args):
print(am_config)
print(voc_config)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# acoustic model
am_inference, am_name, am_dataset = get_am_inference(args, am_config)
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
# whether dygraph to static
if args.inference_dir:
# acoustic model
am_inference = am_to_static(args, am_inference, am_name, am_dataset)
am_inference = am_to_static(
am_inference=am_inference,
am=args.am,
inference_dir=args.inference_dir,
speaker_dict=args.speaker_dict)
# vocoder
voc_inference = voc_to_static(args, voc_inference)
voc_inference = voc_to_static(
voc_inference=voc_inference,
voc=args.voc,
inference_dir=args.inference_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

@ -49,10 +49,13 @@ def evaluate(args):
print(am_config)
print(voc_config)
sentences = get_sentences(args)
sentences = get_sentences(text_file=args.text, lang=args.lang)
# frontend
frontend = get_frontend(args)
frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
@ -60,7 +63,6 @@ def evaluate(args):
print("vocab_size:", vocab_size)
# acoustic model, only support fastspeech2 here now!
# am_inference, am_name, am_dataset = get_am_inference(args, am_config)
# model: {model_name}_{dataset}
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
@ -80,7 +82,11 @@ def evaluate(args):
am_postnet = am.postnet
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
# whether dygraph to static
if args.inference_dir:
@ -115,7 +121,10 @@ def evaluate(args):
os.path.join(args.inference_dir, args.am + "_am_postnet"))
# vocoder
voc_inference = voc_to_static(args, voc_inference)
voc_inference = voc_to_static(
voc_inference=voc_inference,
voc=args.voc,
inference_dir=args.inference_dir)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

@ -66,10 +66,19 @@ def voice_cloning(args):
print("frontend done!")
# acoustic model
am_inference, *_ = get_am_inference(args, am_config)
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict)
# vocoder
voc_inference = get_voc_inference(args, voc_config)
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

@ -58,8 +58,7 @@ def main():
else:
print("ngpu should >= 0 !")
model = WaveRNN(
hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
model = WaveRNN(**config["model"])
state_dict = paddle.load(args.checkpoint)
model.set_state_dict(state_dict["main_params"])

@ -91,3 +91,199 @@ class LogSoftmaxWrapper(nn.Layer):
predictions = F.log_softmax(predictions, axis=1)
loss = self.criterion(predictions, targets) / targets.sum()
return loss
class NCELoss(nn.Layer):
"""Noise Contrastive Estimation loss funtion
Noise Contrastive Estimation (NCE) is an approximation method that is used to
work around the huge computational cost of large softmax layer.
The basic idea is to convert the prediction problem into classification problem
at training stage. It has been proved that these two criterions converges to
the same minimal point as long as noise distribution is close enough to real one.
NCE bridges the gap between generative models and discriminative models,
rather than simply speedup the softmax layer.
With NCE, you can turn almost anything into posterior with less effort (I think).
Refs:
NCEhttp://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann10AISTATS.pdf
Thanks: https://github.com/mingen-pan/easy-to-use-NCE-RNN-for-Pytorch/blob/master/nce.py
Examples:
Q = Q_from_tokens(output_dim)
NCELoss(Q)
"""
def __init__(self, Q, noise_ratio=100, Z_offset=9.5):
"""Noise Contrastive Estimation loss funtion
Args:
Q (tensor): prior model, uniform or guassian
noise_ratio (int, optional): noise sampling times. Defaults to 100.
Z_offset (float, optional): scale of post processing the score. Defaults to 9.5.
"""
super(NCELoss, self).__init__()
assert type(noise_ratio) is int
self.Q = paddle.to_tensor(Q, stop_gradient=False)
self.N = self.Q.shape[0]
self.K = noise_ratio
self.Z_offset = Z_offset
def forward(self, output, target):
"""Forward inference
Args:
output (tensor): the model output, which is the input of loss function
"""
output = paddle.reshape(output, [-1, self.N])
B = output.shape[0]
noise_idx = self.get_noise(B)
idx = self.get_combined_idx(target, noise_idx)
P_target, P_noise = self.get_prob(idx, output, sep_target=True)
Q_target, Q_noise = self.get_Q(idx)
loss = self.nce_loss(P_target, P_noise, Q_noise, Q_target)
return loss.mean()
def get_Q(self, idx, sep_target=True):
"""Get prior model of batchsize data
"""
idx_size = idx.size
prob_model = paddle.to_tensor(
self.Q.numpy()[paddle.reshape(idx, [-1]).numpy()])
prob_model = paddle.reshape(prob_model, [idx.shape[0], idx.shape[1]])
if sep_target:
return prob_model[:, 0], prob_model[:, 1:]
else:
return prob_model
def get_prob(self, idx, scores, sep_target=True):
"""Post processing the score of post model(output of nn) of batchsize data
"""
scores = self.get_scores(idx, scores)
scale = paddle.to_tensor([self.Z_offset], dtype='float64')
scores = paddle.add(scores, -scale)
prob = paddle.exp(scores)
if sep_target:
return prob[:, 0], prob[:, 1:]
else:
return prob
def get_scores(self, idx, scores):
"""Get the score of post model(output of nn) of batchsize data
"""
B, N = scores.shape
K = idx.shape[1]
idx_increment = paddle.to_tensor(
N * paddle.reshape(paddle.arange(B), [B, 1]) * paddle.ones([1, K]),
dtype="int64",
stop_gradient=False)
new_idx = idx_increment + idx
new_scores = paddle.index_select(
paddle.reshape(scores, [-1]), paddle.reshape(new_idx, [-1]))
return paddle.reshape(new_scores, [B, K])
def get_noise(self, batch_size, uniform=True):
"""Select noise sample
"""
if uniform:
noise = np.random.randint(self.N, size=self.K * batch_size)
else:
noise = np.random.choice(
self.N, self.K * batch_size, replace=True, p=self.Q.data)
noise = paddle.to_tensor(noise, dtype='int64', stop_gradient=False)
noise_idx = paddle.reshape(noise, [batch_size, self.K])
return noise_idx
def get_combined_idx(self, target_idx, noise_idx):
"""Combined target and noise
"""
target_idx = paddle.reshape(target_idx, [-1, 1])
return paddle.concat((target_idx, noise_idx), 1)
def nce_loss(self, prob_model, prob_noise_in_model, prob_noise,
prob_target_in_noise):
"""Combined the loss of target and noise
"""
def safe_log(tensor):
"""Safe log
"""
EPSILON = 1e-10
return paddle.log(EPSILON + tensor)
model_loss = safe_log(prob_model /
(prob_model + self.K * prob_target_in_noise))
model_loss = paddle.reshape(model_loss, [-1])
noise_loss = paddle.sum(
safe_log((self.K * prob_noise) /
(prob_noise_in_model + self.K * prob_noise)), -1)
noise_loss = paddle.reshape(noise_loss, [-1])
loss = -(model_loss + noise_loss)
return loss
class FocalLoss(nn.Layer):
"""This criterion is a implemenation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are
instead summed for each minibatch.
"""
def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=-100):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.size_average = size_average
self.ce = nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction="none")
def forward(self, outputs, targets):
"""Forword inference.
Args:
outputs: input tensor
target: target label tensor
"""
ce_loss = self.ce(outputs, targets)
pt = paddle.exp(-ce_loss)
focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
if self.size_average:
return focal_loss.mean()
else:
return focal_loss.sum()
if __name__ == "__main__":
import numpy as np
from paddlespeech.vector.utils.vector_utils import Q_from_tokens
paddle.set_device("cpu")
input_data = paddle.uniform([5, 100], dtype="float64")
label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
input = paddle.to_tensor(input_data)
label = paddle.to_tensor(label_data)
loss1 = FocalLoss()
loss = loss1.forward(input, label)
print("loss: %.5f" % (loss))
Q = Q_from_tokens(100)
loss2 = NCELoss(Q)
loss = loss2.forward(input, label)
print("loss: %.5f" % (loss))

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def get_chunks(seg_dur, audio_id, audio_duration):
@ -30,3 +31,11 @@ def get_chunks(seg_dur, audio_id, audio_duration):
for i in range(num_chunks)
]
return chunk_lst
def Q_from_tokens(token_num):
"""Get prior model, data from uniform, would support others(guassian) in future
"""
freq = [1] * token_num
Q = paddle.to_tensor(freq, dtype='float64')
return Q / Q.sum()

@ -63,7 +63,8 @@ include(libsndfile)
# include(boost) # not work
set(boost_SOURCE_DIR ${fc_patch}/boost-src)
set(BOOST_ROOT ${boost_SOURCE_DIR})
# #find_package(boost REQUIRED PATHS ${BOOST_ROOT})
include_directories(${boost_SOURCE_DIR})
link_directories(${boost_SOURCE_DIR}/stage/lib)
# Eigen
include(eigen)
@ -141,4 +142,4 @@ set(DEPS ${DEPS}
set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
add_subdirectory(speechx)
add_subdirectory(examples)
add_subdirectory(examples)

@ -2,4 +2,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat)
add_subdirectory(nnet)
add_subdirectory(decoder)
add_subdirectory(decoder)
add_subdirectory(websocket)

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../../../
SPEECHX_ROOT=$PWD/../../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat:$SPEECHX_EXAMPLES/ds2_ol/websocket
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -87,7 +87,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--params_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result
@ -102,7 +102,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--params_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm \
@ -129,7 +129,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
wfst-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \
--params_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \

@ -0,0 +1,37 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char
# output
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
export GLOG_logtostderr=1
# websocket client
websocket_client_main \
--wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.36

@ -0,0 +1,66 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
# output
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
if [ ! -d $ckpt_dir ]; then
mkdir -p $ckpt_dir
wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $ckpt_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir
fi
export GLOG_logtostderr=1
# 3. gen cmvn
cmvn=$PWD/cmvn.ark
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
text=$data/test/text
graph_dir=./aishell_graph
if [ ! -d $graph_dir ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
unzip aishell_graph.zip
fi
# 5. test websocket server
websocket_server_main \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_1.jit.pdmodel \
--streaming_chunk=0.1 \
--convert2PCM32=true \
--params_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2

@ -17,3 +17,6 @@ add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc)
target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})

@ -34,12 +34,10 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
@ -58,12 +56,11 @@ int main(int argc, char* argv[]) {
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path;
std::string model_path = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_graph;
LOG(INFO) << "model path: " << model_path;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
@ -76,10 +73,9 @@ int main(int argc, char* argv[]) {
ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.model_path = model_path;
model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
@ -125,7 +121,6 @@ int main(int argc, char* argv[]) {
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
int32 end = start + chunk_size;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);

@ -0,0 +1,85 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
int32 num_done = 0, num_err = 0;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
recognizer.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished();
}
recognizer.Decode();
sample_offset += cur_chunk_size;
}
std::string result;
result = recognizer.GetFinalResult();
recognizer.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
}

@ -73,9 +73,9 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
} catch (simdjson::simdjson_error& err) {
LOG(ERR) << err.what();
LOG(ERROR) << err.what();
}
return 0;
}
}

@ -32,7 +32,6 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
@ -66,7 +65,8 @@ int main(int argc, char* argv[]) {
std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn));
ppspeech::FeatureCacheOptions feat_cache_opts;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;

@ -0,0 +1,10 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc)
target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})
add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})

@ -0,0 +1,82 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_client.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(host, "127.0.0.1", "host of websocket server");
DEFINE_int32(port, 201314, "port of websocket server");
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
using kaldi::int16;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::WebSocketClient client(FLAGS_host, FLAGS_port);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
const int sample_rate = 16000;
const float streaming_chunk = FLAGS_streaming_chunk;
const int chunk_sample_size = streaming_chunk * sample_rate;
for (; !wav_reader.Done(); wav_reader.Next()) {
client.SendStartSignal();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
CHECK_EQ(wave_data.SampFreq(), sample_rate);
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
const int tot_samples = waveform.Dim();
int sample_offset = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<int16> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = static_cast<int16>(waveform(sample_offset + i));
}
client.SendBinaryData(wav_chunk.data(),
wav_chunk.size() * sizeof(int16));
sample_offset += cur_chunk_size;
LOG(INFO) << "Send " << cur_chunk_size << " samples";
std::this_thread::sleep_for(
std::chrono::milliseconds(static_cast<int>(1 * 1000)));
if (cur_chunk_size < chunk_sample_size) {
client.SendEndSignal();
}
}
while (!client.Done()) {
}
std::string result = client.GetResult();
LOG(INFO) << "utt: " << utt << " " << result;
client.Join();
return 0;
}
return 0;
}

@ -0,0 +1,30 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_server.h"
#include "decoder/param.h"
DEFINE_int32(port, 201314, "websocket listening port");
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::WebSocketServer server(FLAGS_port, resource);
LOG(INFO) << "Listening at port " << FLAGS_port;
server.Start();
return 0;
}

@ -30,4 +30,10 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder
)
add_subdirectory(decoder)
add_subdirectory(decoder)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/websocket
)
add_subdirectory(websocket)

@ -28,8 +28,10 @@
#include <sstream>
#include <stack>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "base/basic_types.h"

@ -7,5 +7,6 @@ add_library(decoder STATIC
ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp
ctc_tlg_decoder.cc
recognizer.cc
)
target_link_libraries(decoder PUBLIC kenlm utils fst)
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder)

@ -33,7 +33,6 @@ void TLGDecoder::InitDecoder() {
void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) {
LOG(INFO) << "num frame decode: " << frame_decoded_size_;
AdvanceDecoding(decodable.get());
}
}
@ -63,4 +62,4 @@ std::string TLGDecoder::GetFinalBestPath() {
}
return words;
}
}
}

@ -0,0 +1,94 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
DEFINE_bool(convert2PCM32, true, "audio convert to pcm32");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(params_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
namespace ppspeech {
// todo refactor later
FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk;
opts.convert2PCM32 = FLAGS_convert2PCM32;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.frame_length_ms = 20;
frame_opts.frame_shift_ms = 10;
frame_opts.remove_dc_offset = false;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
frame_opts.dither = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length;
opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate;
return opts;
}
ModelOptions InitModelOptions() {
ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.params_path = FLAGS_params_path;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.output_names = FLAGS_model_output_names;
return model_opts;
}
TLGDecoderOptions InitDecoderOptions() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
return decoder_opts;
}
RecognizerResource InitRecognizerResoure() {
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = InitFeaturePipelineOptions();
resource.model_opts = InitModelOptions();
resource.tlg_opts = InitDecoderOptions();
return resource;
}
}

@ -0,0 +1,60 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
Recognizer::Recognizer(const RecognizerResource& resource) {
// resource_ = resource;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(resource.model_opts));
BaseFloat ac_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale));
decoder_.reset(new TLGDecoder(resource.tlg_opts));
input_finished_ = false;
}
void Recognizer::Accept(const Vector<BaseFloat>& waves) {
feature_pipeline_->Accept(waves);
}
void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); }
std::string Recognizer::GetFinalResult() {
return decoder_->GetFinalBestPath();
}
void Recognizer::SetFinished() {
feature_pipeline_->SetFinished();
input_finished_ = true;
}
bool Recognizer::IsFinished() { return input_finished_; }
void Recognizer::Reset() {
feature_pipeline_->Reset();
decodable_->Reset();
decoder_->Reset();
}
} // namespace ppspeech

@ -0,0 +1,59 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
namespace ppspeech {
struct RecognizerResource {
FeaturePipelineOptions feature_pipeline_opts;
ModelOptions model_opts;
TLGDecoderOptions tlg_opts;
// CTCBeamSearchOptions beam_search_opts;
kaldi::BaseFloat acoustic_scale;
RecognizerResource()
: acoustic_scale(1.0),
feature_pipeline_opts(),
model_opts(),
tlg_opts() {}
};
class Recognizer {
public:
explicit Recognizer(const RecognizerResource& resouce);
void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
void Decode();
std::string GetFinalResult();
void SetFinished();
bool IsFinished();
void Reset();
private:
// std::shared_ptr<RecognizerResource> resource_;
// RecognizerResource resource_;
std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<TLGDecoder> decoder_;
bool input_finished_;
};
} // namespace ppspeech

@ -6,6 +6,7 @@ add_library(frontend STATIC
linear_spectrogram.cc
audio_cache.cc
feature_cache.cc
feature_pipeline.cc
)
target_link_libraries(frontend PUBLIC kaldi-matrix)
target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common)

@ -41,7 +41,7 @@ void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
ready_feed_condition_.wait(lock);
}
for (size_t idx = 0; idx < waves.Dim(); ++idx) {
int32 buffer_idx = (idx + offset_) % ring_buffer_.size();
int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx);
if (convert2PCM32_)
ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx));

@ -24,7 +24,7 @@ namespace ppspeech {
class AudioCache : public FrontendInterface {
public:
explicit AudioCache(int buffer_size = 1000 * kint16max,
bool convert2PCM32 = false);
bool convert2PCM32 = true);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);

@ -23,10 +23,13 @@ using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
FeatureCache::FeatureCache(int max_size,
FeatureCache::FeatureCache(FeatureCacheOptions opts,
unique_ptr<FrontendInterface> base_extractor) {
max_size_ = max_size;
max_size_ = opts.max_size;
frame_chunk_stride_ = opts.frame_chunk_stride;
frame_chunk_size_ = opts.frame_chunk_size;
base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim();
}
void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
@ -44,13 +47,14 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.empty() && base_extractor_->IsFinished() == false) {
ready_read_condition_.wait(lock);
BaseFloat elapsed = timer.Elapsed() * 1000;
// todo replace 1.0 with timeout_
if (elapsed > 1.0) {
// todo refactor: wait
// ready_read_condition_.wait(lock);
int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
// todo replace 1 with timeout_, 1 ms
if (elapsed > 1) {
return false;
}
usleep(1000); // sleep 1 ms
usleep(100); // sleep 0.1 ms
}
if (cache_.empty()) return false;
feats->Resize(cache_.front().Dim());
@ -63,25 +67,41 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
// read all data from base_feature_extractor_ into cache_
bool FeatureCache::Compute() {
// compute and feed
Vector<BaseFloat> feature_chunk;
bool result = base_extractor_->Read(&feature_chunk);
Vector<BaseFloat> feature;
bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false;
int32 joint_len = feature.Dim() + remained_feature_.Dim();
int32 num_chunk =
((joint_len / dim_) - frame_chunk_size_) / frame_chunk_stride_ + 1;
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
ready_feed_condition_.wait(lock);
}
Vector<BaseFloat> joint_feature(joint_len);
joint_feature.Range(0, remained_feature_.Dim())
.CopyFromVec(remained_feature_);
joint_feature.Range(remained_feature_.Dim(), feature.Dim())
.CopyFromVec(feature);
// feed cache
if (feature_chunk.Dim() != 0) {
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * frame_chunk_stride_ * dim_;
Vector<BaseFloat> feature_chunk(frame_chunk_size_ * dim_);
SubVector<BaseFloat> tmp(joint_feature.Data() + start,
frame_chunk_size_ * dim_);
feature_chunk.CopyFromVec(tmp);
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
ready_feed_condition_.wait(lock);
}
// feed cache
cache_.push(feature_chunk);
ready_read_condition_.notify_one();
}
ready_read_condition_.notify_one();
int32 remained_feature_len =
joint_len - num_chunk * frame_chunk_stride_ * dim_;
remained_feature_.Resize(remained_feature_len);
remained_feature_.CopyFromVec(joint_feature.Range(
frame_chunk_stride_ * num_chunk * dim_, remained_feature_len));
return result;
}
void Reset() {
// std::lock_guard<std::mutex> lock(mutex_);
return;
}
} // namespace ppspeech

@ -19,10 +19,18 @@
namespace ppspeech {
struct FeatureCacheOptions {
int32 max_size;
int32 frame_chunk_size;
int32 frame_chunk_stride;
FeatureCacheOptions()
: max_size(kint16max), frame_chunk_size(1), frame_chunk_stride(1) {}
};
class FeatureCache : public FrontendInterface {
public:
explicit FeatureCache(
int32 max_size = kint16max,
FeatureCacheOptions opts,
std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves
@ -32,12 +40,15 @@ class FeatureCache : public FrontendInterface {
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// feat dim
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() {
// std::unique_lock<std::mutex> lock(mutex_);
base_extractor_->SetFinished();
LOG(INFO) << "set finished";
// read the last chunk data
Compute();
// ready_feed_condition_.notify_one();
}
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
@ -52,9 +63,13 @@ class FeatureCache : public FrontendInterface {
private:
bool Compute();
int32 dim_;
size_t max_size_;
std::unique_ptr<FrontendInterface> base_extractor_;
int32 frame_chunk_size_;
int32 frame_chunk_stride_;
kaldi::Vector<kaldi::BaseFloat> remained_feature_;
std::unique_ptr<FrontendInterface> base_extractor_;
std::mutex mutex_;
std::queue<kaldi::Vector<BaseFloat>> cache_;
std::condition_variable ready_feed_condition_;

@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/feature_pipeline.h"
namespace ppspeech {
using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.convert2PCM32));
unique_ptr<FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts,
std::move(data_source)));
unique_ptr<FrontendInterface> cmvn(
new ppspeech::CMVN(opts.cmvn_file, std::move(linear_spectrogram)));
base_extractor_.reset(
new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn)));
}
} // ppspeech

@ -0,0 +1,57 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file;
bool convert2PCM32;
LinearSpectrogramOptions linear_spectrogram_opts;
FeatureCacheOptions feature_cache_opts;
FeaturePipelineOptions()
: cmvn_file(""),
convert2PCM32(false),
linear_spectrogram_opts(),
feature_cache_opts() {}
};
class FeaturePipeline : public FrontendInterface {
public:
explicit FeaturePipeline(const FeaturePipelineOptions& opts);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
return base_extractor_->Read(feats);
}
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); }
private:
std::unique_ptr<FrontendInterface> base_extractor_;
};
}

@ -52,16 +52,16 @@ bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
if (flag == false || input_feats.Dim() == 0) return false;
int32 feat_len = input_feats.Dim();
int32 left_len = reminded_wav_.Dim();
int32 left_len = remained_wav_.Dim();
Vector<BaseFloat> waves(feat_len + left_len);
waves.Range(0, left_len).CopyFromVec(reminded_wav_);
waves.Range(0, left_len).CopyFromVec(remained_wav_);
waves.Range(left_len, feat_len).CopyFromVec(input_feats);
Compute(waves, feats);
int32 frame_shift = opts_.frame_opts.WindowShift();
int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts);
int32 left_samples = waves.Dim() - frame_shift * num_frames;
reminded_wav_.Resize(left_samples);
reminded_wav_.CopyFromVec(
remained_wav_.Resize(left_samples);
remained_wav_.CopyFromVec(
waves.Range(frame_shift * num_frames, left_samples));
return true;
}

@ -25,12 +25,12 @@ struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts;
kaldi::BaseFloat streaming_chunk; // second
LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {}
LinearSpectrogramOptions() : streaming_chunk(0.1), frame_opts() {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("streaming-chunk",
&streaming_chunk,
"streaming chunk size, default: 0.36 sec");
"streaming chunk size, default: 0.1 sec");
frame_opts.Register(opts);
}
};
@ -48,7 +48,7 @@ class LinearSpectrogram : public FrontendInterface {
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
reminded_wav_.Resize(0);
remained_wav_.Resize(0);
}
private:
@ -60,7 +60,7 @@ class LinearSpectrogram : public FrontendInterface {
kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_;
std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> reminded_wav_;
kaldi::Vector<kaldi::BaseFloat> remained_wav_;
int chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
};

@ -78,7 +78,6 @@ bool Decodable::AdvanceChunk() {
}
int32 nnet_dim = 0;
Vector<BaseFloat> inferences;
Matrix<BaseFloat> nnet_cache_tmp;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences);

Loading…
Cancel
Save