Merge branch 'develop' into ngram

pull/1729/head
Hui Zhang 2 years ago committed by GitHub
commit c938a450b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -27,20 +27,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--phones_dict=dump/phone_id_map.txt --phones_dict=dump/phone_id_map.txt
fi fi
# style melgan
# style melgan's Dygraph to Static Graph is not ready now
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \
--inference_dir=${train_output_path}/inference \
--am=tacotron2_csmsc \
--voc=style_melgan_csmsc \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/pd_infer_out \
--phones_dict=dump/phone_id_map.txt
fi
# hifigan # hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \ python3 ${BIN_DIR}/../inference.py \
--inference_dir=${train_output_path}/inference \ --inference_dir=${train_output_path}/inference \
--am=tacotron2_csmsc \ --am=tacotron2_csmsc \

@ -28,7 +28,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
--phones_dict=dump/phone_id_map.txt --phones_dict=dump/phone_id_map.txt
fi fi
# hifigan # hifigan
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/../inference.py \ python3 ${BIN_DIR}/../inference.py \

@ -109,6 +109,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
--lang=zh \ --lang=zh \
--text=${BIN_DIR}/../sentences.txt \ --text=${BIN_DIR}/../sentences.txt \
--output_dir=${train_output_path}/test_e2e \ --output_dir=${train_output_path}/test_e2e \
--phones_dict=dump/phone_id_map.txt \ --phones_dict=dump/phone_id_map.txt #\
--inference_dir=${train_output_path}/inference # --inference_dir=${train_output_path}/inference
fi fi

@ -26,7 +26,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi fi
# hifigan # hifigan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
FLAGS_allocator_strategy=naive_best_fit \ FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \ python3 ${BIN_DIR}/../synthesize.py \

@ -40,7 +40,6 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor'] __all__ = ['ASRExecutor']
@cli_register( @cli_register(
name='paddlespeech.asr', description='Speech to text infer command.') name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor): class ASRExecutor(BaseExecutor):
@ -125,6 +124,7 @@ class ASRExecutor(BaseExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
logger.info("start to init the model")
if hasattr(self, 'model'): if hasattr(self, 'model'):
logger.info('Model had been initialized.') logger.info('Model had been initialized.')
return return
@ -140,14 +140,15 @@ class ASRExecutor(BaseExecutor):
res_path, res_path,
self.pretrained_models[tag]['ckpt_path'] + ".pdparams") self.pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
self.res_path = os.path.dirname( self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path) self.config.merge_from_file(self.cfg_path)
@ -176,7 +177,6 @@ class ASRExecutor(BaseExecutor):
vocab=self.config.vocab_filepath, vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix) spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.decoding_method = decode_method self.config.decode.decoding_method = decode_method
else: else:
raise Exception("wrong type") raise Exception("wrong type")
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
@ -254,12 +254,14 @@ class ASRExecutor(BaseExecutor):
else: else:
raise Exception("wrong type") raise Exception("wrong type")
logger.info("audio feat process success")
@paddle.no_grad() @paddle.no_grad()
def infer(self, model_type: str): def infer(self, model_type: str):
""" """
Model inference and result stored in self.output. Model inference and result stored in self.output.
""" """
logger.info("start to infer the model to get the output")
cfg = self.config.decode cfg = self.config.decode
audio = self._inputs["audio"] audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"] audio_len = self._inputs["audio_len"]
@ -276,17 +278,22 @@ class ASRExecutor(BaseExecutor):
self._outputs["result"] = result_transcripts[0] self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
result_transcripts = self.model.decode( logger.info(f"we will use the transformer like model : {model_type}")
audio, try:
audio_len, result_transcripts = self.model.decode(
text_feature=self.text_feature, audio,
decoding_method=cfg.decoding_method, audio_len,
beam_size=cfg.beam_size, text_feature=self.text_feature,
ctc_weight=cfg.ctc_weight, decoding_method=cfg.decoding_method,
decoding_chunk_size=cfg.decoding_chunk_size, beam_size=cfg.beam_size,
num_decoding_left_chunks=cfg.num_decoding_left_chunks, ctc_weight=cfg.ctc_weight,
simulate_streaming=cfg.simulate_streaming) decoding_chunk_size=cfg.decoding_chunk_size,
self._outputs["result"] = result_transcripts[0][0] num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming)
self._outputs["result"] = result_transcripts[0][0]
except Exception as e:
logger.exception(e)
else: else:
raise Exception("invalid model name") raise Exception("invalid model name")

@ -88,6 +88,8 @@ model_alias = {
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer": "conformer":
"paddlespeech.s2t.models.u2:U2Model", "paddlespeech.s2t.models.u2:U2Model",
"conformer_online":
"paddlespeech.s2t.models.u2:U2Model",
"transformer": "transformer":
"paddlespeech.s2t.models.u2:U2Model", "paddlespeech.s2t.models.u2:U2Model",
"wenetspeech": "wenetspeech":

@ -279,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
# TODO(Hui Zhang): if end_flag.sum() == running_size: # TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size: if end_flag.cast(paddle.int64).sum() == running_size:
break break
# 2.1 Forward decoder step # 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i) running_size, 1, 1).to(device) # (B*N, i, i)
# logp: (B*N, vocab) # logp: (B*N, vocab)
logp, cache = self.decoder.forward_one_step( logp, cache = self.decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache) encoder_out, encoder_mask, hyps, hyps_mask, cache)
# 2.2 First beam prune: select topk best prob at current time # 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag) top_k_logp = mask_finished_scores(top_k_logp, end_flag)
@ -708,11 +707,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
batch_size = feats.shape[0] batch_size = feats.shape[0]
if decoding_method in ['ctc_prefix_beam_search', if decoding_method in ['ctc_prefix_beam_search',
'attention_rescoring'] and batch_size > 1: 'attention_rescoring'] and batch_size > 1:
logger.fatal( logger.error(
f'decoding mode {decoding_method} must be running with batch_size == 1' f'decoding mode {decoding_method} must be running with batch_size == 1'
) )
logger.error(f"current batch_size is {batch_size}")
sys.exit(1) sys.exit(1)
if decoding_method == 'attention': if decoding_method == 'attention':
hyps = self.recognize( hyps = self.recognize(
feats, feats,

@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase):
# init once # init once
if self._ext_scorer is not None: if self._ext_scorer is not None:
return return
if language_model_path != '': if language_model_path != '':
logger.info("begin to initialize the external scorer " logger.info("begin to initialize the external scorer "
"for decoding") "for decoding")

@ -35,3 +35,16 @@
```bash ```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
``` ```
## Online ASR Server
### Lanuch online asr server
```
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### Access online asr server
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
```

@ -35,3 +35,17 @@
```bash ```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
``` ```
## 流式ASR
### 启动流式语音识别服务
```
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### 访问流式语音识别服务
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
```

@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor):
lang=lang, lang=lang,
audio_format=audio_format) audio_format=audio_format)
time_end = time.time() time_end = time.time()
logger.info(res.json()) logger.info(res)
logger.info("Response time %f s." % (time_end - time_start)) logger.info("Response time %f s." % (time_end - time_start))
return True return True
except Exception as e: except Exception as e:
logger.error("Failed to speech recognition.") logger.error("Failed to speech recognition.")
logger.error(e)
return False return False
@stats_wrapper @stats_wrapper
@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor):
logging.info("asr websocket client start") logging.info("asr websocket client start")
handler = ASRAudioHandler(server_ip, port) handler = ASRAudioHandler(server_ip, port)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(input)) res = loop.run_until_complete(handler.run(input))
logging.info("asr websocket client finished") logging.info("asr websocket client finished")
return res['asr_results']
@cli_client_register( @cli_client_register(
name='paddlespeech_client.cls', description='visit cls service') name='paddlespeech_client.cls', description='visit cls service')

@ -41,11 +41,7 @@ asr_online:
shift_ms: 40 shift_ms: 40
sample_rate: 16000 sample_rate: 16000
sample_width: 2 sample_width: 2
window_n: 7 # frame
vad_conf: shift_n: 4 # frame
aggressiveness: 2 window_ms: 20 # ms
sample_rate: 16000 shift_ms: 10 # ms
frame_duration_ms: 20
sample_width: 2
padding_ms: 200
padding_ratio: 0.9

@ -0,0 +1,45 @@
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
# The task format in the engin_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only support online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'conformer_online_multicn'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 25 # ms
shift_ms: 10 # ms
sample_rate: 16000
sample_width: 2

File diff suppressed because it is too large Load Diff

@ -0,0 +1,128 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddlespeech.cli.log import logger
from paddlespeech.s2t.utils.utility import log_add
__all__ = ['CTCPrefixBeamSearch']
class CTCPrefixBeamSearch:
def __init__(self, config):
"""Implement the ctc prefix beam search
Args:
config (yacs.config.CfgNode): _description_
"""
self.config = config
self.reset()
@paddle.no_grad()
def search(self, ctc_probs, device, blank_id=0):
"""ctc prefix beam search method decode a chunk feature
Args:
xs (paddle.Tensor): feature data
ctc_probs (paddle.Tensor): the ctc probability of all the tokens
device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0).
blank_id (int, optional): the blank id in the vocab. Defaults to 0.
Returns:
list: the search result
"""
# decode
logger.info("start to ctc prefix search")
batch_size = 1
beam_size = self.config.beam_size
maxlen = ctc_probs.shape[0]
assert len(ctc_probs.shape) == 2
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
if self.cur_hyps is None:
self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in self.cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
self.cur_hyps = next_hyps[:beam_size]
self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
logger.info("ctc prefix search success")
return self.hyps
def get_one_best_hyps(self):
"""Return the one best result
Returns:
list: the one best result
"""
return [self.hyps[0][0]]
def get_hyps(self):
"""Return the search hyps
Returns:
list: return the search hyps
"""
return self.hyps
def reset(self):
"""Rest the search cache value
"""
self.cur_hyps = None
self.hyps = None
def finalize_search(self):
"""do nothing in ctc_prefix_beam_search
"""
pass

@ -12,24 +12,329 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import base64 import base64
import math
import os
import time import time
from typing import Optional
import numpy as np import numpy as np
import paddle import paddle
import yaml
from yacs.config import CfgNode
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks from paddlespeech.server.utils.util import get_chunks
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore
__all__ = ['TTSEngine']
# support online model
pretrained_models = {
# fastspeech2
"fastspeech2_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip',
'md5':
'637d28a5e53aa60275612ba4393d5f22',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_76000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
"fastspeech2_cnndecoder_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip',
'md5':
'6eb28e22ace73e0ebe7845f86478f89f',
'config':
'cnndecoder.yaml',
'ckpt':
'snapshot_iter_153000.pdz',
'speech_stats':
'speech_stats.npy',
'phones_dict':
'phone_id_map.txt',
},
# mb_melgan
"mb_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'ee5f0604e20091f0d495b6ec4618b90d',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1000000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
}
model_alias = {
# acoustic model
"fastspeech2":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2",
"fastspeech2_inference":
"paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference",
# voc
"mb_melgan":
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
}
__all__ = ['TTSEngine'] __all__ = ['TTSEngine']
class TTSServerExecutor(TTSExecutor): class TTSServerExecutor(TTSExecutor):
def __init__(self): def __init__(self, am_block, am_pad, voc_block, voc_pad):
super().__init__() super().__init__()
pass self.am_block = am_block
self.am_pad = am_pad
self.voc_block = voc_block
self.voc_pad = voc_pad
def get_model_info(self,
field: str,
model_name: str,
ckpt: Optional[os.PathLike],
stat: Optional[os.PathLike]):
"""get model information
Args:
field (str): am or voc
model_name (str): model type, support fastspeech2, higigan, mb_melgan
ckpt (Optional[os.PathLike]): ckpt file
stat (Optional[os.PathLike]): stat file, including mean and standard deviation
Returns:
[module]: model module
[Tensor]: mean
[Tensor]: standard deviation
"""
model_class = dynamic_import(model_name, model_alias)
if field == "am":
odim = self.am_config.n_mels
model = model_class(
idim=self.vocab_size, odim=odim, **self.am_config["model"])
model.set_state_dict(paddle.load(ckpt)["main_params"])
elif field == "voc":
model = model_class(**self.voc_config["generator_params"])
model.set_state_dict(paddle.load(ckpt)["generator_params"])
model.remove_weight_norm()
else:
logger.error("Please set correct field, am or voc")
model.eval()
model_mu, model_std = np.load(stat)
model_mu = paddle.to_tensor(model_mu)
model_std = paddle.to_tensor(model_std)
return model, model_mu, model_std
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(
self,
am: str='fastspeech2_csmsc',
am_config: Optional[os.PathLike]=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
voc: str='mb_melgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None,
lang: str='zh', ):
"""
Init model and other resources from a specific path.
"""
if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'):
logger.info('Models had been initialized.')
return
# am model info
am_tag = am + '-' + lang
if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None:
am_res_path = self._get_pretrained_path(am_tag)
self.am_res_path = am_res_path
self.am_config = os.path.join(am_res_path,
pretrained_models[am_tag]['config'])
self.am_ckpt = os.path.join(am_res_path,
pretrained_models[am_tag]['ckpt'])
self.am_stat = os.path.join(
am_res_path, pretrained_models[am_tag]['speech_stats'])
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
am_res_path, pretrained_models[am_tag]['phones_dict'])
print("self.phones_dict:", self.phones_dict)
logger.info(am_res_path)
logger.info(self.am_config)
logger.info(self.am_ckpt)
else:
self.am_config = os.path.abspath(am_config)
self.am_ckpt = os.path.abspath(am_ckpt)
self.am_stat = os.path.abspath(am_stat)
self.phones_dict = os.path.abspath(phones_dict)
self.am_res_path = os.path.dirname(os.path.abspath(self.am_config))
print("self.phones_dict:", self.phones_dict)
self.tones_dict = None
self.speaker_dict = None
# voc model info
voc_tag = voc + '-' + lang
if voc_ckpt is None or voc_config is None or voc_stat is None:
voc_res_path = self._get_pretrained_path(voc_tag)
self.voc_res_path = voc_res_path
self.voc_config = os.path.join(voc_res_path,
pretrained_models[voc_tag]['config'])
self.voc_ckpt = os.path.join(voc_res_path,
pretrained_models[voc_tag]['ckpt'])
self.voc_stat = os.path.join(
voc_res_path, pretrained_models[voc_tag]['speech_stats'])
logger.info(voc_res_path)
logger.info(self.voc_config)
logger.info(self.voc_ckpt)
else:
self.voc_config = os.path.abspath(voc_config)
self.voc_ckpt = os.path.abspath(voc_ckpt)
self.voc_stat = os.path.abspath(voc_stat)
self.voc_res_path = os.path.dirname(
os.path.abspath(self.voc_config))
# Init body.
with open(self.am_config) as f:
self.am_config = CfgNode(yaml.safe_load(f))
with open(self.voc_config) as f:
self.voc_config = CfgNode(yaml.safe_load(f))
with open(self.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
self.vocab_size = len(phn_id)
print("vocab_size:", self.vocab_size)
# frontend
if lang == 'zh':
self.frontend = Frontend(
phone_vocab_path=self.phones_dict,
tone_vocab_path=self.tones_dict)
elif lang == 'en':
self.frontend = English(phone_vocab_path=self.phones_dict)
print("frontend done!")
# am infer info
self.am_name = am[:am.rindex('_')]
if self.am_name == "fastspeech2_cnndecoder":
self.am_inference, self.am_mu, self.am_std = self.get_model_info(
"am", "fastspeech2", self.am_ckpt, self.am_stat)
else:
am, am_mu, am_std = self.get_model_info("am", self.am_name,
self.am_ckpt, self.am_stat)
am_normalizer = ZScore(am_mu, am_std)
am_inference_class = dynamic_import(self.am_name + '_inference',
model_alias)
self.am_inference = am_inference_class(am_normalizer, am)
self.am_inference.eval()
print("acoustic model done!")
# voc infer info
self.voc_name = voc[:voc.rindex('_')]
voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name,
self.voc_ckpt, self.voc_stat)
voc_normalizer = ZScore(voc_mu, voc_std)
voc_inference_class = dynamic_import(self.voc_name + '_inference',
model_alias)
self.voc_inference = voc_inference_class(voc_normalizer, voc)
self.voc_inference.eval()
print("voc done!")
def get_phone(self, sentence, lang, merge_sentences, get_tone_ids):
tone_ids = None
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")
def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
"""
front_pad = min(chunk_id * block, pad)
# first chunk
if chunk_id == 0:
data = data[:block * upsample]
# last chunk
elif chunk_id == chunk_num - 1:
data = data[front_pad * upsample:]
# middle chunk
else:
data = data[front_pad * upsample:(front_pad + block) * upsample]
return data
@paddle.no_grad() @paddle.no_grad()
def infer( def infer(
@ -37,16 +342,20 @@ class TTSServerExecutor(TTSExecutor):
text: str, text: str,
lang: str='zh', lang: str='zh',
am: str='fastspeech2_csmsc', am: str='fastspeech2_csmsc',
spk_id: int=0, spk_id: int=0, ):
am_block: int=42,
am_pad: int=12,
voc_block: int=14,
voc_pad: int=14, ):
""" """
Model inference and result stored in self.output. Model inference and result stored in self.output.
""" """
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:] am_block = self.am_block
am_pad = self.am_pad
am_upsample = 1
voc_block = self.voc_block
voc_pad = self.voc_pad
voc_upsample = self.voc_config.n_shift
# first_flag 用于标记首包
first_flag = 1
get_tone_ids = False get_tone_ids = False
merge_sentences = False merge_sentences = False
frontend_st = time.time() frontend_st = time.time()
@ -64,43 +373,100 @@ class TTSServerExecutor(TTSExecutor):
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
else: else:
print("lang should in {'zh', 'en'}!") print("lang should in {'zh', 'en'}!")
self.frontend_time = time.time() - frontend_st frontend_et = time.time()
self.frontend_time = frontend_et - frontend_st
for i in range(len(phone_ids)): for i in range(len(phone_ids)):
am_st = time.time()
part_phone_ids = phone_ids[i] part_phone_ids = phone_ids[i]
# am voc_chunk_id = 0
if am_name == 'speedyspeech':
part_tone_ids = tone_ids[i] # fastspeech2_csmsc
mel = self.am_inference(part_phone_ids, part_tone_ids) if am == "fastspeech2_csmsc":
# fastspeech2 # am
mel = self.am_inference(part_phone_ids)
if first_flag == 1:
first_am_et = time.time()
self.first_am_infer = first_am_et - frontend_et
# voc streaming
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
voc_chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_inference(mel_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num, i,
voc_block, voc_pad, voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
self.first_response_time = first_voc_et - frontend_st
first_flag = 0
yield sub_wav
# fastspeech2_cnndecoder_csmsc
elif am == "fastspeech2_cnndecoder_csmsc":
# am
orig_hs, h_masks = self.am_inference.encoder_infer(
part_phone_ids)
# streaming voc chunk info
mel_len = orig_hs.shape[1]
voc_chunk_num = math.ceil(mel_len / self.voc_block)
start = 0
end = min(self.voc_block + self.voc_pad, mel_len)
# streaming am
hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
before_outs, _ = self.am_inference.decoder(hs)
after_outs = before_outs + self.am_inference.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
normalized_mel = after_outs[0]
sub_mel = denorm(normalized_mel, self.am_mu, self.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block,
am_pad, am_upsample)
if i == 0:
mel_streaming = sub_mel
else:
mel_streaming = np.concatenate(
(mel_streaming, sub_mel), axis=0)
# streaming voc
# 当流式AM推理的mel帧数大于流式voc推理的chunk size开始进行流式voc 推理
while (mel_streaming.shape[0] >= end and
voc_chunk_id < voc_chunk_num):
if first_flag == 1:
first_am_et = time.time()
self.first_am_infer = first_am_et - frontend_et
voc_chunk = mel_streaming[start:end, :]
voc_chunk = paddle.to_tensor(voc_chunk)
sub_wav = self.voc_inference(voc_chunk)
sub_wav = self.depadding(sub_wav, voc_chunk_num,
voc_chunk_id, voc_block,
voc_pad, voc_upsample)
if first_flag == 1:
first_voc_et = time.time()
self.first_voc_infer = first_voc_et - first_am_et
self.first_response_time = first_voc_et - frontend_st
first_flag = 0
yield sub_wav
voc_chunk_id += 1
start = max(0, voc_chunk_id * voc_block - voc_pad)
end = min((voc_chunk_id + 1) * voc_block + voc_pad,
mel_len)
else: else:
# multi speaker logger.error(
if am_dataset in {"aishell3", "vctk"}: "Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts."
mel = self.am_inference( )
part_phone_ids, spk_id=paddle.to_tensor(spk_id))
else: self.final_response_time = time.time() - frontend_st
mel = self.am_inference(part_phone_ids)
am_et = time.time()
# voc streaming
voc_upsample = self.voc_config.n_shift
mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc")
chunk_num = len(mel_chunks)
voc_st = time.time()
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.voc_inference(mel_chunk)
front_pad = min(i * voc_block, voc_pad)
if i == 0:
sub_wav = sub_wav[:voc_block * voc_upsample]
elif i == chunk_num - 1:
sub_wav = sub_wav[front_pad * voc_upsample:]
else:
sub_wav = sub_wav[front_pad * voc_upsample:(
front_pad + voc_block) * voc_upsample]
yield sub_wav
class TTSEngine(BaseEngine): class TTSEngine(BaseEngine):
@ -113,14 +479,21 @@ class TTSEngine(BaseEngine):
def __init__(self, name=None): def __init__(self, name=None):
"""Initialize TTS server engine """Initialize TTS server engine
""" """
super(TTSEngine, self).__init__() super().__init__()
def init(self, config: dict) -> bool: def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
self.config = config self.config = config
assert "fastspeech2_csmsc" in config.am and ( assert (
config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc" config.am == "fastspeech2_csmsc" or
config.am == "fastspeech2_cnndecoder_csmsc"
) and (
config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc"
), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
assert (
config.voc_block > 0 and config.voc_pad > 0
), "Please set correct voc_block and voc_pad, they should be more than 0."
try: try:
if self.config.device: if self.config.device:
self.device = self.config.device self.device = self.config.device
@ -135,6 +508,9 @@ class TTSEngine(BaseEngine):
(self.device)) (self.device))
return False return False
self.executor = TTSServerExecutor(config.am_block, config.am_pad,
config.voc_block, config.voc_pad)
try: try:
self.executor._init_from_path( self.executor._init_from_path(
am=self.config.am, am=self.config.am,
@ -155,15 +531,42 @@ class TTSEngine(BaseEngine):
(self.device)) (self.device))
return False return False
self.am_block = self.config.am_block
self.am_pad = self.config.am_pad
self.voc_block = self.config.voc_block
self.voc_pad = self.config.voc_pad
logger.info("Initialize TTS server engine successfully on device: %s." % logger.info("Initialize TTS server engine successfully on device: %s." %
(self.device)) (self.device))
# warm up
try:
self.warm_up()
except Exception as e:
logger.error("Failed to warm up on tts engine.")
return False
return True return True
def warm_up(self):
"""warm up
"""
if self.config.lang == 'zh':
sentence = "您好,欢迎使用语音合成服务。"
if self.config.lang == 'en':
sentence = "Hello and welcome to the speech synthesis service."
logger.info(
"*******************************warm up ********************************"
)
for i in range(3):
for wav in self.executor.infer(
text=sentence,
lang=self.config.lang,
am=self.config.am,
spk_id=0, ):
logger.info(
f"The first response time of the {i} warm up: {self.executor.first_response_time} s"
)
break
logger.info(
"**********************************************************************"
)
def preprocess(self, text_bese64: str=None, text_bytes: bytes=None): def preprocess(self, text_bese64: str=None, text_bytes: bytes=None):
# Convert byte to text # Convert byte to text
if text_bese64: if text_bese64:
@ -195,18 +598,14 @@ class TTSEngine(BaseEngine):
wav_base64: The base64 format of the synthesized audio. wav_base64: The base64 format of the synthesized audio.
""" """
lang = self.config.lang
wav_list = [] wav_list = []
for wav in self.executor.infer( for wav in self.executor.infer(
text=sentence, text=sentence,
lang=lang, lang=self.config.lang,
am=self.config.am, am=self.config.am,
spk_id=spk_id, spk_id=spk_id, ):
am_block=self.am_block,
am_pad=self.am_pad,
voc_block=self.voc_block,
voc_pad=self.voc_pad):
# wav type: <class 'numpy.ndarray'> float32, convert to pcm (base64) # wav type: <class 'numpy.ndarray'> float32, convert to pcm (base64)
wav = float2pcm(wav) # float32 to int16 wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes wav_bytes = wav.tobytes() # to bytes
@ -216,5 +615,14 @@ class TTSEngine(BaseEngine):
yield wav_base64 yield wav_base64
wav_all = np.concatenate(wav_list, axis=0) wav_all = np.concatenate(wav_list, axis=0)
logger.info("The durations of audio is: {} s".format( duration = len(wav_all) / self.executor.am_config.fs
len(wav_all) / self.executor.am_config.fs)) logger.info(f"sentence: {sentence}")
logger.info(f"The durations of audio is: {duration} s")
logger.info(
f"first response time: {self.executor.first_response_time} s")
logger.info(
f"final response time: {self.executor.final_response_time} s")
logger.info(f"RTF: {self.executor.final_response_time / duration}")
logger.info(
f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s,"
)

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -34,10 +34,9 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str): def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples) x_len = len(samples)
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size != 0: chunk_size = 85 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size!= 0:
padding_len_x = chunk_size - x_len % chunk_size padding_len_x = chunk_size - x_len % chunk_size
else: else:
padding_len_x = 0 padding_len_x = 0
@ -48,7 +47,6 @@ class ASRAudioHandler:
assert (x_len + padding_len_x) % chunk_size == 0 assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk) num_chunk = int(num_chunk)
for i in range(0, num_chunk): for i in range(0, num_chunk):
start = i * chunk_size start = i * chunk_size
end = start + chunk_size end = start + chunk_size
@ -57,7 +55,11 @@ class ASRAudioHandler:
async def run(self, wavfile_path: str): async def run(self, wavfile_path: str):
logging.info("send a message to the server") logging.info("send a message to the server")
# self.read_wave()
# send websocket handshake protocal
async with websockets.connect(self.url) as ws: async with websockets.connect(self.url) as ws:
# server has already received handshake protocal
# client start to send the command
audio_info = json.dumps( audio_info = json.dumps(
{ {
"name": "test.wav", "name": "test.wav",
@ -78,7 +80,6 @@ class ASRAudioHandler:
msg = json.loads(msg) msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
result = msg
# finished # finished
audio_info = json.dumps( audio_info = json.dumps(
{ {
@ -91,10 +92,12 @@ class ASRAudioHandler:
separators=(',', ': ')) separators=(',', ': '))
await ws.send(audio_info) await ws.send(audio_info)
msg = await ws.recv() msg = await ws.recv()
# decode the bytes to str
msg = json.loads(msg) msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("final receive msg={}".format(msg))
result = msg
return result return result
def main(args): def main(args):

@ -63,12 +63,12 @@ class ChunkBuffer(object):
the sample rate. the sample rate.
Yields Frames of the requested duration. Yields Frames of the requested duration.
""" """
audio = self.remained_audio + audio audio = self.remained_audio + audio
self.remained_audio = b'' self.remained_audio = b''
offset = 0 offset = 0
timestamp = 0.0 timestamp = 0.0
while offset + self.window_bytes <= len(audio): while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp, yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec) self.window_sec)

@ -52,6 +52,10 @@ def get_chunks(data, block_size, pad_size, step):
Returns: Returns:
list: chunks list list: chunks list
""" """
if block_size == -1:
return [data]
if step == "am": if step == "am":
data_len = data.shape[1] data_len = data.shape[1]
elif step == "voc": elif step == "voc":

@ -13,12 +13,12 @@
# limitations under the License. # limitations under the License.
import json import json
import numpy as np
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import WebSocket from fastapi import WebSocket
from fastapi import WebSocketDisconnect from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio from paddlespeech.server.utils.vad import VADAudio
@ -28,26 +28,29 @@ router = APIRouter()
@router.websocket('/ws/asr') @router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
await websocket.accept() await websocket.accept()
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
connection_handler = None
# init buffer # init buffer
# each websocekt connection has its own chunk buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer( chunk_buffer = ChunkBuffer(
window_n=7, window_n=chunk_buffer_conf.window_n,
shift_n=4, shift_n=chunk_buffer_conf.shift_n,
window_ms=20, window_ms=chunk_buffer_conf.window_ms,
shift_ms=10, shift_ms=chunk_buffer_conf.shift_ms,
sample_rate=chunk_buffer_conf['sample_rate'], sample_rate=chunk_buffer_conf.sample_rate,
sample_width=chunk_buffer_conf['sample_width']) sample_width=chunk_buffer_conf.sample_width)
# init vad # init vad
vad_conf = asr_engine.config.vad_conf vad_conf = asr_engine.config.get('vad_conf', None)
vad = VADAudio( if vad_conf:
aggressiveness=vad_conf['aggressiveness'], vad = VADAudio(
rate=vad_conf['sample_rate'], aggressiveness=vad_conf['aggressiveness'],
frame_duration_ms=vad_conf['frame_duration_ms']) rate=vad_conf['sample_rate'],
frame_duration_ms=vad_conf['frame_duration_ms'])
try: try:
while True: while True:
@ -64,13 +67,21 @@ async def websocket_endpoint(websocket: WebSocket):
if message['signal'] == 'start': if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"} resp = {"status": "ok", "signal": "server_ready"}
# do something at begining here # do something at begining here
# create the instance to process the audio
connection_handler = PaddleASRConnectionHanddler(asr_engine)
await websocket.send_json(resp) await websocket.send_json(resp)
elif message['signal'] == 'end': elif message['signal'] == 'end':
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# reset single engine for an new connection # reset single engine for an new connection
asr_engine.reset() connection_handler.decode(is_finished=True)
resp = {"status": "ok", "signal": "finished"} connection_handler.rescoring()
asr_results = connection_handler.get_result()
connection_handler.reset()
resp = {
"status": "ok",
"signal": "finished",
'asr_results': asr_results
}
await websocket.send_json(resp) await websocket.send_json(resp)
break break
else: else:
@ -79,21 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message: elif "bytes" in message:
message = message["bytes"] message = message["bytes"]
engine_pool = get_engine_pool() connection_handler.extract_feat(message)
asr_engine = engine_pool['asr'] connection_handler.decode(is_finished=False)
asr_results = "" asr_results = connection_handler.get_result()
frames = chunk_buffer.frame_generator(message)
for frame in frames:
samples = np.frombuffer(frame.bytes, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate)
asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess()
asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results} resp = {'asr_results': asr_results}
await websocket.send_json(resp) await websocket.send_json(resp)
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass

@ -14,6 +14,7 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -101,21 +102,35 @@ def parse_args():
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
# frontend # frontend
frontend = get_frontend(args) frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# am_predictor # am_predictor
am_predictor = get_predictor(args, filed='am') am_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + ".pdmodel",
params_file=args.am + ".pdiparams",
device=args.device)
# model: {model_name}_{dataset} # model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:] am_dataset = args.am[args.am.rindex('_') + 1:]
# voc_predictor # voc_predictor
voc_predictor = get_predictor(args, filed='voc') voc_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel",
params_file=args.voc + ".pdiparams",
device=args.device)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args) sentences = get_sentences(text_file=args.text, lang=args.lang)
merge_sentences = True merge_sentences = True
fs = 24000 if am_dataset != 'ljspeech' else 22050 fs = 24000 if am_dataset != 'ljspeech' else 22050
@ -123,11 +138,13 @@ def main():
for utt_id, sentence in sentences[:3]: for utt_id, sentence in sentences[:3]:
with timer() as t: with timer() as t:
am_output_data = get_am_output( am_output_data = get_am_output(
args, input=sentence,
am_predictor=am_predictor, am_predictor=am_predictor,
am=args.am,
frontend=frontend, frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
input=sentence) speaker_dict=args.speaker_dict, )
wav = get_voc_output( wav = get_voc_output(
voc_predictor=voc_predictor, input=am_output_data) voc_predictor=voc_predictor, input=am_output_data)
speed = wav.size / t.elapse speed = wav.size / t.elapse
@ -143,11 +160,13 @@ def main():
for utt_id, sentence in sentences: for utt_id, sentence in sentences:
with timer() as t: with timer() as t:
am_output_data = get_am_output( am_output_data = get_am_output(
args, input=sentence,
am_predictor=am_predictor, am_predictor=am_predictor,
am=args.am,
frontend=frontend, frontend=frontend,
lang=args.lang,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
input=sentence) speaker_dict=args.speaker_dict, )
wav = get_voc_output( wav = get_voc_output(
voc_predictor=voc_predictor, input=am_output_data) voc_predictor=voc_predictor, input=am_output_data)

@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -25,7 +26,6 @@ from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_predictor from paddlespeech.t2s.exps.syn_utils import get_predictor
from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_predictor
from paddlespeech.t2s.exps.syn_utils import get_voc_output from paddlespeech.t2s.exps.syn_utils import get_voc_output
from paddlespeech.t2s.utils import str2bool from paddlespeech.t2s.utils import str2bool
@ -101,23 +101,47 @@ def parse_args():
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
# frontend # frontend
frontend = get_frontend(args) frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# am_predictor # am_predictor
am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor = get_streaming_am_predictor(
args) am_encoder_infer_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_encoder_infer" + ".pdmodel",
params_file=args.am + "_am_encoder_infer" + ".pdiparams",
device=args.device)
am_decoder_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_decoder" + ".pdmodel",
params_file=args.am + "_am_decoder" + ".pdiparams",
device=args.device)
am_postnet_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.am + "_am_postnet" + ".pdmodel",
params_file=args.am + "_am_postnet" + ".pdiparams",
device=args.device)
am_mu, am_std = np.load(args.am_stat) am_mu, am_std = np.load(args.am_stat)
# model: {model_name}_{dataset} # model: {model_name}_{dataset}
am_dataset = args.am[args.am.rindex('_') + 1:] am_dataset = args.am[args.am.rindex('_') + 1:]
# voc_predictor # voc_predictor
voc_predictor = get_predictor(args, filed='voc') voc_predictor = get_predictor(
model_dir=args.inference_dir,
model_file=args.voc + ".pdmodel",
params_file=args.voc + ".pdiparams",
device=args.device)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args) sentences = get_sentences(text_file=args.text, lang=args.lang)
merge_sentences = True merge_sentences = True
@ -126,13 +150,13 @@ def main():
for utt_id, sentence in sentences[:3]: for utt_id, sentence in sentences[:3]:
with timer() as t: with timer() as t:
normalized_mel = get_streaming_am_output( normalized_mel = get_streaming_am_output(
args, input=sentence,
am_encoder_infer_predictor=am_encoder_infer_predictor, am_encoder_infer_predictor=am_encoder_infer_predictor,
am_decoder_predictor=am_decoder_predictor, am_decoder_predictor=am_decoder_predictor,
am_postnet_predictor=am_postnet_predictor, am_postnet_predictor=am_postnet_predictor,
frontend=frontend, frontend=frontend,
merge_sentences=merge_sentences, lang=args.lang,
input=sentence) merge_sentences=merge_sentences, )
mel = denorm(normalized_mel, am_mu, am_std) mel = denorm(normalized_mel, am_mu, am_std)
wav = get_voc_output(voc_predictor=voc_predictor, input=mel) wav = get_voc_output(voc_predictor=voc_predictor, input=mel)
speed = wav.size / t.elapse speed = wav.size / t.elapse

@ -16,6 +16,7 @@ from pathlib import Path
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -25,12 +26,13 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args): def ort_predict(args):
# construct dataset for evaluation # construct dataset for evaluation
with jsonlines.open(args.test_metadata, 'r') as reader: with jsonlines.open(args.test_metadata, 'r') as reader:
test_metadata = list(reader) test_metadata = list(reader)
am_name = args.am[:args.am.rindex('_')] am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:] am_dataset = args.am[args.am.rindex('_') + 1:]
test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset) test_dataset = get_test_dataset(test_metadata=test_metadata, am=args.am)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
@ -38,10 +40,18 @@ def ort_predict(args):
fs = 24000 if am_dataset != 'ljspeech' else 22050 fs = 24000 if am_dataset != 'ljspeech' else 22050
# am # am
am_sess = get_sess(args, filed='am') am_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# vocoder # vocoder
voc_sess = get_sess(args, filed='voc') voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# am warmup # am warmup
for T in [27, 38, 54]: for T in [27, 38, 54]:
@ -135,6 +145,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)

@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -27,21 +28,31 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args): def ort_predict(args):
# frontend # frontend
frontend = get_frontend(args) frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args) sentences = get_sentences(text_file=args.text, lang=args.lang)
am_name = args.am[:args.am.rindex('_')] am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:] am_dataset = args.am[args.am.rindex('_') + 1:]
fs = 24000 if am_dataset != 'ljspeech' else 22050 fs = 24000 if am_dataset != 'ljspeech' else 22050
# am am_sess = get_sess(
am_sess = get_sess(args, filed='am') model_dir=args.inference_dir,
model_file=args.am + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# vocoder # vocoder
voc_sess = get_sess(args, filed='voc') voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# frontend warmup # frontend warmup
# Loading model cost 0.5+ seconds # Loading model cost 0.5+ seconds
@ -168,6 +179,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)

@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
@ -23,30 +24,50 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks
from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sentences
from paddlespeech.t2s.exps.syn_utils import get_sess from paddlespeech.t2s.exps.syn_utils import get_sess
from paddlespeech.t2s.exps.syn_utils import get_streaming_am_sess
from paddlespeech.t2s.utils import str2bool from paddlespeech.t2s.utils import str2bool
def ort_predict(args): def ort_predict(args):
# frontend # frontend
frontend = get_frontend(args) frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
sentences = get_sentences(args) sentences = get_sentences(text_file=args.text, lang=args.lang)
am_name = args.am[:args.am.rindex('_')] am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:] am_dataset = args.am[args.am.rindex('_') + 1:]
fs = 24000 if am_dataset != 'ljspeech' else 22050 fs = 24000 if am_dataset != 'ljspeech' else 22050
# am # streaming acoustic model
am_encoder_infer_sess, am_decoder_sess, am_postnet_sess = get_streaming_am_sess( am_encoder_infer_sess = get_sess(
args) model_dir=args.inference_dir,
model_file=args.am + "_am_encoder_infer" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_decoder_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_decoder" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_postnet_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.am + "_am_postnet" + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
am_mu, am_std = np.load(args.am_stat) am_mu, am_std = np.load(args.am_stat)
# vocoder # vocoder
voc_sess = get_sess(args, filed='voc') voc_sess = get_sess(
model_dir=args.inference_dir,
model_file=args.voc + ".onnx",
device=args.device,
cpu_threads=args.cpu_threads)
# frontend warmup # frontend warmup
# Loading model cost 0.5+ seconds # Loading model cost 0.5+ seconds
@ -226,6 +247,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)

@ -14,6 +14,10 @@
import math import math
import os import os
from pathlib import Path from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
import numpy as np import numpy as np
import onnxruntime as ort import onnxruntime as ort
@ -21,6 +25,7 @@ import paddle
from paddle import inference from paddle import inference
from paddle import jit from paddle import jit
from paddle.static import InputSpec from paddle.static import InputSpec
from yacs.config import CfgNode
from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.data_table import DataTable
@ -70,7 +75,7 @@ def denorm(data, mean, std):
return data * std + mean return data * std + mean
def get_chunks(data, chunk_size, pad_size): def get_chunks(data, chunk_size: int, pad_size: int):
data_len = data.shape[1] data_len = data.shape[1]
chunks = [] chunks = []
n = math.ceil(data_len / chunk_size) n = math.ceil(data_len / chunk_size)
@ -82,28 +87,34 @@ def get_chunks(data, chunk_size, pad_size):
# input # input
def get_sentences(args): def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
# construct dataset for evaluation # construct dataset for evaluation
sentences = [] sentences = []
with open(args.text, 'rt') as f: with open(text_file, 'rt') as f:
for line in f: for line in f:
items = line.strip().split() items = line.strip().split()
utt_id = items[0] utt_id = items[0]
if 'lang' in args and args.lang == 'zh': if lang == 'zh':
sentence = "".join(items[1:]) sentence = "".join(items[1:])
elif 'lang' in args and args.lang == 'en': elif lang == 'en':
sentence = " ".join(items[1:]) sentence = " ".join(items[1:])
sentences.append((utt_id, sentence)) sentences.append((utt_id, sentence))
return sentences return sentences
def get_test_dataset(args, test_metadata, am_name, am_dataset): def get_test_dataset(test_metadata: List[Dict[str, Any]],
am: str,
speaker_dict: Optional[os.PathLike]=None,
voice_cloning: bool=False):
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2': if am_name == 'fastspeech2':
fields = ["utt_id", "text"] fields = ["utt_id", "text"]
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
print("multiple speaker fastspeech2!") print("multiple speaker fastspeech2!")
fields += ["spk_id"] fields += ["spk_id"]
elif 'voice_cloning' in args and args.voice_cloning: elif voice_cloning:
print("voice cloning!") print("voice cloning!")
fields += ["spk_emb"] fields += ["spk_emb"]
else: else:
@ -112,7 +123,7 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset):
fields = ["utt_id", "phones", "tones"] fields = ["utt_id", "phones", "tones"]
elif am_name == 'tacotron2': elif am_name == 'tacotron2':
fields = ["utt_id", "text"] fields = ["utt_id", "text"]
if 'voice_cloning' in args and args.voice_cloning: if voice_cloning:
print("voice cloning!") print("voice cloning!")
fields += ["spk_emb"] fields += ["spk_emb"]
@ -121,12 +132,14 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset):
# frontend # frontend
def get_frontend(args): def get_frontend(lang: str='zh',
if 'lang' in args and args.lang == 'zh': phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None):
if lang == 'zh':
frontend = Frontend( frontend = Frontend(
phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict) phone_vocab_path=phones_dict, tone_vocab_path=tones_dict)
elif 'lang' in args and args.lang == 'en': elif lang == 'en':
frontend = English(phone_vocab_path=args.phones_dict) frontend = English(phone_vocab_path=phones_dict)
else: else:
print("wrong lang!") print("wrong lang!")
print("frontend done!") print("frontend done!")
@ -134,30 +147,37 @@ def get_frontend(args):
# dygraph # dygraph
def get_am_inference(args, am_config): def get_am_inference(
with open(args.phones_dict, "r") as f: am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None, ):
with open(phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id) vocab_size = len(phn_id)
print("vocab_size:", vocab_size) print("vocab_size:", vocab_size)
tone_size = None tone_size = None
if 'tones_dict' in args and args.tones_dict: if tones_dict is not None:
with open(args.tones_dict, "r") as f: with open(tones_dict, "r") as f:
tone_id = [line.strip().split() for line in f.readlines()] tone_id = [line.strip().split() for line in f.readlines()]
tone_size = len(tone_id) tone_size = len(tone_id)
print("tone_size:", tone_size) print("tone_size:", tone_size)
spk_num = None spk_num = None
if 'speaker_dict' in args and args.speaker_dict: if speaker_dict is not None:
with open(args.speaker_dict, 'rt') as f: with open(speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()] spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id) spk_num = len(spk_id)
print("spk_num:", spk_num) print("spk_num:", spk_num)
odim = am_config.n_mels odim = am_config.n_mels
# model: {model_name}_{dataset} # model: {model_name}_{dataset}
am_name = args.am[:args.am.rindex('_')] am_name = am[:am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:] am_dataset = am[am.rindex('_') + 1:]
am_class = dynamic_import(am_name, model_alias) am_class = dynamic_import(am_name, model_alias)
am_inference_class = dynamic_import(am_name + '_inference', model_alias) am_inference_class = dynamic_import(am_name + '_inference', model_alias)
@ -174,34 +194,38 @@ def get_am_inference(args, am_config):
elif am_name == 'tacotron2': elif am_name == 'tacotron2':
am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) am.set_state_dict(paddle.load(am_ckpt)["main_params"])
am.eval() am.eval()
am_mu, am_std = np.load(args.am_stat) am_mu, am_std = np.load(am_stat)
am_mu = paddle.to_tensor(am_mu) am_mu = paddle.to_tensor(am_mu)
am_std = paddle.to_tensor(am_std) am_std = paddle.to_tensor(am_std)
am_normalizer = ZScore(am_mu, am_std) am_normalizer = ZScore(am_mu, am_std)
am_inference = am_inference_class(am_normalizer, am) am_inference = am_inference_class(am_normalizer, am)
am_inference.eval() am_inference.eval()
print("acoustic model done!") print("acoustic model done!")
return am_inference, am_name, am_dataset return am_inference
def get_voc_inference(args, voc_config): def get_voc_inference(
voc: str='pwgan_csmsc',
voc_config: Optional[os.PathLike]=None,
voc_ckpt: Optional[os.PathLike]=None,
voc_stat: Optional[os.PathLike]=None, ):
# model: {model_name}_{dataset} # model: {model_name}_{dataset}
voc_name = args.voc[:args.voc.rindex('_')] voc_name = voc[:voc.rindex('_')]
voc_class = dynamic_import(voc_name, model_alias) voc_class = dynamic_import(voc_name, model_alias)
voc_inference_class = dynamic_import(voc_name + '_inference', model_alias) voc_inference_class = dynamic_import(voc_name + '_inference', model_alias)
if voc_name != 'wavernn': if voc_name != 'wavernn':
voc = voc_class(**voc_config["generator_params"]) voc = voc_class(**voc_config["generator_params"])
voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"]) voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"])
voc.remove_weight_norm() voc.remove_weight_norm()
voc.eval() voc.eval()
else: else:
voc = voc_class(**voc_config["model"]) voc = voc_class(**voc_config["model"])
voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"]) voc.set_state_dict(paddle.load(voc_ckpt)["main_params"])
voc.eval() voc.eval()
voc_mu, voc_std = np.load(args.voc_stat) voc_mu, voc_std = np.load(voc_stat)
voc_mu = paddle.to_tensor(voc_mu) voc_mu = paddle.to_tensor(voc_mu)
voc_std = paddle.to_tensor(voc_std) voc_std = paddle.to_tensor(voc_std)
voc_normalizer = ZScore(voc_mu, voc_std) voc_normalizer = ZScore(voc_mu, voc_std)
@ -211,10 +235,16 @@ def get_voc_inference(args, voc_config):
return voc_inference return voc_inference
# to static # dygraph to static graph
def am_to_static(args, am_inference, am_name, am_dataset): def am_to_static(am_inference,
am: str='fastspeech2_csmsc',
inference_dir=Optional[os.PathLike],
speaker_dict: Optional[os.PathLike]=None):
# model: {model_name}_{dataset}
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2': if am_name == 'fastspeech2':
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static( am_inference = jit.to_static(
am_inference, am_inference,
input_spec=[ input_spec=[
@ -226,7 +256,7 @@ def am_to_static(args, am_inference, am_name, am_dataset):
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
elif am_name == 'speedyspeech': elif am_name == 'speedyspeech':
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
am_inference = jit.to_static( am_inference = jit.to_static(
am_inference, am_inference,
input_spec=[ input_spec=[
@ -247,56 +277,64 @@ def am_to_static(args, am_inference, am_name, am_dataset):
am_inference = jit.to_static( am_inference = jit.to_static(
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am)) paddle.jit.save(am_inference, os.path.join(inference_dir, am))
am_inference = paddle.jit.load(os.path.join(args.inference_dir, args.am)) am_inference = paddle.jit.load(os.path.join(inference_dir, am))
return am_inference return am_inference
def voc_to_static(args, voc_inference): def voc_to_static(voc_inference,
voc: str='pwgan_csmsc',
inference_dir=Optional[os.PathLike]):
voc_inference = jit.to_static( voc_inference = jit.to_static(
voc_inference, input_spec=[ voc_inference, input_spec=[
InputSpec([-1, 80], dtype=paddle.float32), InputSpec([-1, 80], dtype=paddle.float32),
]) ])
paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc)) paddle.jit.save(voc_inference, os.path.join(inference_dir, voc))
voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc)) voc_inference = paddle.jit.load(os.path.join(inference_dir, voc))
return voc_inference return voc_inference
# inference # inference
def get_predictor(args, filed='am'): def get_predictor(model_dir: Optional[os.PathLike]=None,
full_name = '' model_file: Optional[os.PathLike]=None,
if filed == 'am': params_file: Optional[os.PathLike]=None,
full_name = args.am device: str='cpu'):
elif filed == 'voc':
full_name = args.voc
config = inference.Config( config = inference.Config(
str(Path(args.inference_dir) / (full_name + ".pdmodel")), str(Path(model_dir) / model_file), str(Path(model_dir) / params_file))
str(Path(args.inference_dir) / (full_name + ".pdiparams"))) if device == "gpu":
if args.device == "gpu":
config.enable_use_gpu(100, 0) config.enable_use_gpu(100, 0)
elif args.device == "cpu": elif device == "cpu":
config.disable_gpu() config.disable_gpu()
config.enable_memory_optim() config.enable_memory_optim()
predictor = inference.create_predictor(config) predictor = inference.create_predictor(config)
return predictor return predictor
def get_am_output(args, am_predictor, frontend, merge_sentences, input): def get_am_output(
am_name = args.am[:args.am.rindex('_')] input: str,
am_dataset = args.am[args.am.rindex('_') + 1:] am_predictor,
am,
frontend,
lang: str='zh',
merge_sentences: bool=True,
speaker_dict: Optional[os.PathLike]=None,
spk_id: int=0, ):
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
am_input_names = am_predictor.get_input_names() am_input_names = am_predictor.get_input_names()
get_tone_ids = False get_tone_ids = False
get_spk_id = False get_spk_id = False
if am_name == 'speedyspeech': if am_name == 'speedyspeech':
get_tone_ids = True get_tone_ids = True
if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: if am_dataset in {"aishell3", "vctk"} and speaker_dict:
get_spk_id = True get_spk_id = True
spk_id = np.array([args.spk_id]) spk_id = np.array([spk_id])
if args.lang == 'zh': if lang == 'zh':
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
elif args.lang == 'en': elif lang == 'en':
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences) input, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
@ -338,50 +376,6 @@ def get_voc_output(voc_predictor, input):
return wav return wav
# streaming am
def get_streaming_am_predictor(args):
full_name = args.am
am_encoder_infer_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_encoder_infer" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_encoder_infer" + ".pdiparams")))
am_decoder_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_decoder" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_decoder" + ".pdiparams")))
am_postnet_config = inference.Config(
str(
Path(args.inference_dir) /
(full_name + "_am_postnet" + ".pdmodel")),
str(
Path(args.inference_dir) /
(full_name + "_am_postnet" + ".pdiparams")))
if args.device == "gpu":
am_encoder_infer_config.enable_use_gpu(100, 0)
am_decoder_config.enable_use_gpu(100, 0)
am_postnet_config.enable_use_gpu(100, 0)
elif args.device == "cpu":
am_encoder_infer_config.disable_gpu()
am_decoder_config.disable_gpu()
am_postnet_config.disable_gpu()
am_encoder_infer_config.enable_memory_optim()
am_decoder_config.enable_memory_optim()
am_postnet_config.enable_memory_optim()
am_encoder_infer_predictor = inference.create_predictor(
am_encoder_infer_config)
am_decoder_predictor = inference.create_predictor(am_decoder_config)
am_postnet_predictor = inference.create_predictor(am_postnet_config)
return am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor
def get_am_sublayer_output(am_sublayer_predictor, input): def get_am_sublayer_output(am_sublayer_predictor, input):
am_sublayer_input_names = am_sublayer_predictor.get_input_names() am_sublayer_input_names = am_sublayer_predictor.get_input_names()
input_handle = am_sublayer_predictor.get_input_handle( input_handle = am_sublayer_predictor.get_input_handle(
@ -397,11 +391,15 @@ def get_am_sublayer_output(am_sublayer_predictor, input):
return am_sublayer_output return am_sublayer_output
def get_streaming_am_output(args, am_encoder_infer_predictor, def get_streaming_am_output(input: str,
am_decoder_predictor, am_postnet_predictor, am_encoder_infer_predictor,
frontend, merge_sentences, input): am_decoder_predictor,
am_postnet_predictor,
frontend,
lang: str='zh',
merge_sentences: bool=True):
get_tone_ids = False get_tone_ids = False
if args.lang == 'zh': if lang == 'zh':
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"] phone_ids = input_ids["phone_ids"]
@ -423,58 +421,27 @@ def get_streaming_am_output(args, am_encoder_infer_predictor,
return normalized_mel return normalized_mel
def get_sess(args, filed='am'): # onnx
full_name = '' def get_sess(model_dir: Optional[os.PathLike]=None,
if filed == 'am': model_file: Optional[os.PathLike]=None,
full_name = args.am device: str='cpu',
elif filed == 'voc': cpu_threads: int=1,
full_name = args.voc use_trt: bool=False):
model_dir = str(Path(args.inference_dir) / (full_name + ".onnx"))
model_dir = str(Path(model_dir) / model_file)
sess_options = ort.SessionOptions() sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if args.device == "gpu": if device == "gpu":
# fastspeech2/mb_melgan can't use trt now! # fastspeech2/mb_melgan can't use trt now!
if args.use_trt: if use_trt:
providers = ['TensorrtExecutionProvider'] providers = ['TensorrtExecutionProvider']
else: else:
providers = ['CUDAExecutionProvider'] providers = ['CUDAExecutionProvider']
elif args.device == "cpu": elif device == "cpu":
providers = ['CPUExecutionProvider'] providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = args.cpu_threads sess_options.intra_op_num_threads = cpu_threads
sess = ort.InferenceSession( sess = ort.InferenceSession(
model_dir, providers=providers, sess_options=sess_options) model_dir, providers=providers, sess_options=sess_options)
return sess return sess
# streaming am
def get_streaming_am_sess(args):
full_name = args.am
am_encoder_infer_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_encoder_infer" + ".onnx"))
am_decoder_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_decoder" + ".onnx"))
am_postnet_model_dir = str(
Path(args.inference_dir) / (full_name + "_am_postnet" + ".onnx"))
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
if args.device == "gpu":
# fastspeech2/mb_melgan can't use trt now!
if args.use_trt:
providers = ['TensorrtExecutionProvider']
else:
providers = ['CUDAExecutionProvider']
elif args.device == "cpu":
providers = ['CPUExecutionProvider']
sess_options.intra_op_num_threads = args.cpu_threads
am_encoder_infer_sess = ort.InferenceSession(
am_encoder_infer_model_dir,
providers=providers,
sess_options=sess_options)
am_decoder_sess = ort.InferenceSession(
am_decoder_model_dir, providers=providers, sess_options=sess_options)
am_postnet_sess = ort.InferenceSession(
am_postnet_model_dir, providers=providers, sess_options=sess_options)
return am_encoder_infer_sess, am_decoder_sess, am_postnet_sess

@ -50,11 +50,29 @@ def evaluate(args):
print(voc_config) print(voc_config)
# acoustic model # acoustic model
am_inference, am_name, am_dataset = get_am_inference(args, am_config) am_name = args.am[:args.am.rindex('_')]
test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset) am_dataset = args.am[args.am.rindex('_') + 1:]
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
test_dataset = get_test_dataset(
test_metadata=test_metadata,
am=args.am,
speaker_dict=args.speaker_dict,
voice_cloning=args.voice_cloning)
# vocoder # vocoder
voc_inference = get_voc_inference(args, voc_config) voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)

@ -42,24 +42,48 @@ def evaluate(args):
print(am_config) print(am_config)
print(voc_config) print(voc_config)
sentences = get_sentences(args) sentences = get_sentences(text_file=args.text, lang=args.lang)
# frontend # frontend
frontend = get_frontend(args) frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
# acoustic model # acoustic model
am_inference, am_name, am_dataset = get_am_inference(args, am_config) am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
# vocoder # vocoder
voc_inference = get_voc_inference(args, voc_config) voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
# whether dygraph to static # whether dygraph to static
if args.inference_dir: if args.inference_dir:
# acoustic model # acoustic model
am_inference = am_to_static(args, am_inference, am_name, am_dataset) am_inference = am_to_static(
am_inference=am_inference,
am=args.am,
inference_dir=args.inference_dir,
speaker_dict=args.speaker_dict)
# vocoder # vocoder
voc_inference = voc_to_static(args, voc_inference) voc_inference = voc_to_static(
voc_inference=voc_inference,
voc=args.voc,
inference_dir=args.inference_dir)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)

@ -49,10 +49,13 @@ def evaluate(args):
print(am_config) print(am_config)
print(voc_config) print(voc_config)
sentences = get_sentences(args) sentences = get_sentences(text_file=args.text, lang=args.lang)
# frontend # frontend
frontend = get_frontend(args) frontend = get_frontend(
lang=args.lang,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict)
with open(args.phones_dict, "r") as f: with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
@ -60,7 +63,6 @@ def evaluate(args):
print("vocab_size:", vocab_size) print("vocab_size:", vocab_size)
# acoustic model, only support fastspeech2 here now! # acoustic model, only support fastspeech2 here now!
# am_inference, am_name, am_dataset = get_am_inference(args, am_config)
# model: {model_name}_{dataset} # model: {model_name}_{dataset}
am_name = args.am[:args.am.rindex('_')] am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:] am_dataset = args.am[args.am.rindex('_') + 1:]
@ -80,7 +82,11 @@ def evaluate(args):
am_postnet = am.postnet am_postnet = am.postnet
# vocoder # vocoder
voc_inference = get_voc_inference(args, voc_config) voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
# whether dygraph to static # whether dygraph to static
if args.inference_dir: if args.inference_dir:
@ -115,7 +121,10 @@ def evaluate(args):
os.path.join(args.inference_dir, args.am + "_am_postnet")) os.path.join(args.inference_dir, args.am + "_am_postnet"))
# vocoder # vocoder
voc_inference = voc_to_static(args, voc_inference) voc_inference = voc_to_static(
voc_inference=voc_inference,
voc=args.voc,
inference_dir=args.inference_dir)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)

@ -66,10 +66,19 @@ def voice_cloning(args):
print("frontend done!") print("frontend done!")
# acoustic model # acoustic model
am_inference, *_ = get_am_inference(args, am_config) am_inference = get_am_inference(
am=args.am,
am_config=am_config,
am_ckpt=args.am_ckpt,
am_stat=args.am_stat,
phones_dict=args.phones_dict)
# vocoder # vocoder
voc_inference = get_voc_inference(args, voc_config) voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)

@ -58,8 +58,7 @@ def main():
else: else:
print("ngpu should >= 0 !") print("ngpu should >= 0 !")
model = WaveRNN( model = WaveRNN(**config["model"])
hop_length=config.n_shift, sample_rate=config.fs, **config["model"])
state_dict = paddle.load(args.checkpoint) state_dict = paddle.load(args.checkpoint)
model.set_state_dict(state_dict["main_params"]) model.set_state_dict(state_dict["main_params"])

@ -91,3 +91,199 @@ class LogSoftmaxWrapper(nn.Layer):
predictions = F.log_softmax(predictions, axis=1) predictions = F.log_softmax(predictions, axis=1)
loss = self.criterion(predictions, targets) / targets.sum() loss = self.criterion(predictions, targets) / targets.sum()
return loss return loss
class NCELoss(nn.Layer):
"""Noise Contrastive Estimation loss funtion
Noise Contrastive Estimation (NCE) is an approximation method that is used to
work around the huge computational cost of large softmax layer.
The basic idea is to convert the prediction problem into classification problem
at training stage. It has been proved that these two criterions converges to
the same minimal point as long as noise distribution is close enough to real one.
NCE bridges the gap between generative models and discriminative models,
rather than simply speedup the softmax layer.
With NCE, you can turn almost anything into posterior with less effort (I think).
Refs:
NCEhttp://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann10AISTATS.pdf
Thanks: https://github.com/mingen-pan/easy-to-use-NCE-RNN-for-Pytorch/blob/master/nce.py
Examples:
Q = Q_from_tokens(output_dim)
NCELoss(Q)
"""
def __init__(self, Q, noise_ratio=100, Z_offset=9.5):
"""Noise Contrastive Estimation loss funtion
Args:
Q (tensor): prior model, uniform or guassian
noise_ratio (int, optional): noise sampling times. Defaults to 100.
Z_offset (float, optional): scale of post processing the score. Defaults to 9.5.
"""
super(NCELoss, self).__init__()
assert type(noise_ratio) is int
self.Q = paddle.to_tensor(Q, stop_gradient=False)
self.N = self.Q.shape[0]
self.K = noise_ratio
self.Z_offset = Z_offset
def forward(self, output, target):
"""Forward inference
Args:
output (tensor): the model output, which is the input of loss function
"""
output = paddle.reshape(output, [-1, self.N])
B = output.shape[0]
noise_idx = self.get_noise(B)
idx = self.get_combined_idx(target, noise_idx)
P_target, P_noise = self.get_prob(idx, output, sep_target=True)
Q_target, Q_noise = self.get_Q(idx)
loss = self.nce_loss(P_target, P_noise, Q_noise, Q_target)
return loss.mean()
def get_Q(self, idx, sep_target=True):
"""Get prior model of batchsize data
"""
idx_size = idx.size
prob_model = paddle.to_tensor(
self.Q.numpy()[paddle.reshape(idx, [-1]).numpy()])
prob_model = paddle.reshape(prob_model, [idx.shape[0], idx.shape[1]])
if sep_target:
return prob_model[:, 0], prob_model[:, 1:]
else:
return prob_model
def get_prob(self, idx, scores, sep_target=True):
"""Post processing the score of post model(output of nn) of batchsize data
"""
scores = self.get_scores(idx, scores)
scale = paddle.to_tensor([self.Z_offset], dtype='float64')
scores = paddle.add(scores, -scale)
prob = paddle.exp(scores)
if sep_target:
return prob[:, 0], prob[:, 1:]
else:
return prob
def get_scores(self, idx, scores):
"""Get the score of post model(output of nn) of batchsize data
"""
B, N = scores.shape
K = idx.shape[1]
idx_increment = paddle.to_tensor(
N * paddle.reshape(paddle.arange(B), [B, 1]) * paddle.ones([1, K]),
dtype="int64",
stop_gradient=False)
new_idx = idx_increment + idx
new_scores = paddle.index_select(
paddle.reshape(scores, [-1]), paddle.reshape(new_idx, [-1]))
return paddle.reshape(new_scores, [B, K])
def get_noise(self, batch_size, uniform=True):
"""Select noise sample
"""
if uniform:
noise = np.random.randint(self.N, size=self.K * batch_size)
else:
noise = np.random.choice(
self.N, self.K * batch_size, replace=True, p=self.Q.data)
noise = paddle.to_tensor(noise, dtype='int64', stop_gradient=False)
noise_idx = paddle.reshape(noise, [batch_size, self.K])
return noise_idx
def get_combined_idx(self, target_idx, noise_idx):
"""Combined target and noise
"""
target_idx = paddle.reshape(target_idx, [-1, 1])
return paddle.concat((target_idx, noise_idx), 1)
def nce_loss(self, prob_model, prob_noise_in_model, prob_noise,
prob_target_in_noise):
"""Combined the loss of target and noise
"""
def safe_log(tensor):
"""Safe log
"""
EPSILON = 1e-10
return paddle.log(EPSILON + tensor)
model_loss = safe_log(prob_model /
(prob_model + self.K * prob_target_in_noise))
model_loss = paddle.reshape(model_loss, [-1])
noise_loss = paddle.sum(
safe_log((self.K * prob_noise) /
(prob_noise_in_model + self.K * prob_noise)), -1)
noise_loss = paddle.reshape(noise_loss, [-1])
loss = -(model_loss + noise_loss)
return loss
class FocalLoss(nn.Layer):
"""This criterion is a implemenation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are
instead summed for each minibatch.
"""
def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=-100):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.size_average = size_average
self.ce = nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction="none")
def forward(self, outputs, targets):
"""Forword inference.
Args:
outputs: input tensor
target: target label tensor
"""
ce_loss = self.ce(outputs, targets)
pt = paddle.exp(-ce_loss)
focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
if self.size_average:
return focal_loss.mean()
else:
return focal_loss.sum()
if __name__ == "__main__":
import numpy as np
from paddlespeech.vector.utils.vector_utils import Q_from_tokens
paddle.set_device("cpu")
input_data = paddle.uniform([5, 100], dtype="float64")
label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
input = paddle.to_tensor(input_data)
label = paddle.to_tensor(label_data)
loss1 = FocalLoss()
loss = loss1.forward(input, label)
print("loss: %.5f" % (loss))
Q = Q_from_tokens(100)
loss2 = NCELoss(Q)
loss = loss2.forward(input, label)
print("loss: %.5f" % (loss))

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle
def get_chunks(seg_dur, audio_id, audio_duration): def get_chunks(seg_dur, audio_id, audio_duration):
@ -30,3 +31,11 @@ def get_chunks(seg_dur, audio_id, audio_duration):
for i in range(num_chunks) for i in range(num_chunks)
] ]
return chunk_lst return chunk_lst
def Q_from_tokens(token_num):
"""Get prior model, data from uniform, would support others(guassian) in future
"""
freq = [1] * token_num
Q = paddle.to_tensor(freq, dtype='float64')
return Q / Q.sum()

@ -63,7 +63,8 @@ include(libsndfile)
# include(boost) # not work # include(boost) # not work
set(boost_SOURCE_DIR ${fc_patch}/boost-src) set(boost_SOURCE_DIR ${fc_patch}/boost-src)
set(BOOST_ROOT ${boost_SOURCE_DIR}) set(BOOST_ROOT ${boost_SOURCE_DIR})
# #find_package(boost REQUIRED PATHS ${BOOST_ROOT}) include_directories(${boost_SOURCE_DIR})
link_directories(${boost_SOURCE_DIR}/stage/lib)
# Eigen # Eigen
include(eigen) include(eigen)
@ -141,4 +142,4 @@ set(DEPS ${DEPS}
set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx) set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
add_subdirectory(speechx) add_subdirectory(speechx)
add_subdirectory(examples) add_subdirectory(examples)

@ -2,4 +2,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_subdirectory(feat) add_subdirectory(feat)
add_subdirectory(nnet) add_subdirectory(nnet)
add_subdirectory(decoder) add_subdirectory(decoder)
add_subdirectory(websocket)

@ -1,6 +1,6 @@
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../../../ SPEECHX_ROOT=$PWD/../../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
export LC_AL=C export LC_AL=C
SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat:$SPEECHX_EXAMPLES/ds2_ol/websocket
export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN

@ -87,7 +87,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
ctc-prefix-beam-search-decoder-ol \ ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --params_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$vocb_dir/vocab.txt \ --dict_file=$vocb_dir/vocab.txt \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result --result_wspecifier=ark,t:$data/split${nj}/JOB/result
@ -102,7 +102,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
ctc-prefix-beam-search-decoder-ol \ ctc-prefix-beam-search-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --params_path=$model_dir/avg_1.jit.pdiparams \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--dict_file=$vocb_dir/vocab.txt \ --dict_file=$vocb_dir/vocab.txt \
--lm_path=$lm \ --lm_path=$lm \
@ -129,7 +129,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
wfst-decoder-ol \ wfst-decoder-ol \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --params_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \ --word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \ --graph_path=$graph_dir/TLG.fst --max_active=7500 \

@ -0,0 +1,37 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char
# output
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
export GLOG_logtostderr=1
# websocket client
websocket_client_main \
--wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.36

@ -0,0 +1,66 @@
#!/bin/bash
set +x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# input
mkdir -p data
data=$PWD/data
ckpt_dir=$data/model
model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/
vocb_dir=$ckpt_dir/data/lang_char/
# output
aishell_wav_scp=aishell_test.scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
if [ ! -d $ckpt_dir ]; then
mkdir -p $ckpt_dir
wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz
tar xzfv $ckpt_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir
fi
export GLOG_logtostderr=1
# 3. gen cmvn
cmvn=$PWD/cmvn.ark
cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn
text=$data/test/text
graph_dir=./aishell_graph
if [ ! -d $graph_dir ]; then
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip
unzip aishell_graph.zip
fi
# 5. test websocket server
websocket_server_main \
--cmvn_file=$cmvn \
--model_path=$model_dir/avg_1.jit.pdmodel \
--streaming_chunk=0.1 \
--convert2PCM32=true \
--params_path=$model_dir/avg_1.jit.pdiparams \
--word_symbol_table=$graph_dir/words.txt \
--model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \
--graph_path=$graph_dir/TLG.fst --max_active=7500 \
--acoustic_scale=1.2

@ -17,3 +17,6 @@ add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc)
target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS})

@ -34,12 +34,10 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate, DEFINE_int32(downsampling_rate,
4, 4,
"two CNN(kernel=5) module downsampling rate."); "two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names, DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", "save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names"); "model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
@ -58,12 +56,11 @@ int main(int argc, char* argv[]) {
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_path = FLAGS_model_path;
std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file; std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path; std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_graph; LOG(INFO) << "model path: " << model_path;
LOG(INFO) << "model param: " << model_params; LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file; LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path; LOG(INFO) << "lm path: " << lm_path;
@ -76,10 +73,9 @@ int main(int argc, char* argv[]) {
ppspeech::CTCBeamSearch decoder(opts); ppspeech::CTCBeamSearch decoder(opts);
ppspeech::ModelOptions model_opts; ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph; model_opts.model_path = model_path;
model_opts.params_path = model_params; model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names; model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names; model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
@ -125,7 +121,6 @@ int main(int argc, char* argv[]) {
if (feature_chunk_size < receptive_field_length) break; if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride; int32 start = chunk_idx * chunk_stride;
int32 end = start + chunk_size;
for (int row_id = 0; row_id < chunk_size; ++row_id) { for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start); kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);

@ -0,0 +1,85 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
#include "decoder/param.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::Recognizer recognizer(resource);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate;
LOG(INFO) << "sr: " << sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
int32 num_done = 0, num_err = 0;
for (; !wav_reader.Done(); wav_reader.Next()) {
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i);
}
recognizer.Accept(wav_chunk);
if (cur_chunk_size < chunk_sample_size) {
recognizer.SetFinished();
}
recognizer.Decode();
sample_offset += cur_chunk_size;
}
std::string result;
result = recognizer.GetFinalResult();
recognizer.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
}

@ -73,9 +73,9 @@ int main(int argc, char* argv[]) {
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path; LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary; LOG(INFO) << "Binary: " << FLAGS_binary;
} catch (simdjson::simdjson_error& err) { } catch (simdjson::simdjson_error& err) {
LOG(ERR) << err.what(); LOG(ERROR) << err.what();
} }
return 0; return 0;
} }

@ -32,7 +32,6 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn"); DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
@ -66,7 +65,8 @@ int main(int argc, char* argv[]) {
std::unique_ptr<ppspeech::FrontendInterface> cmvn( std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram))); new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram)));
ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); ppspeech::FeatureCacheOptions feat_cache_opts;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "feat dim: " << feature_cache.Dim(); LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000; int sample_rate = 16000;

@ -0,0 +1,10 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc)
target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})
add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS})

@ -0,0 +1,82 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_client.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(host, "127.0.0.1", "host of websocket server");
DEFINE_int32(port, 201314, "port of websocket server");
DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
using kaldi::int16;
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::WebSocketClient client(FLAGS_host, FLAGS_port);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier);
const int sample_rate = 16000;
const float streaming_chunk = FLAGS_streaming_chunk;
const int chunk_sample_size = streaming_chunk * sample_rate;
for (; !wav_reader.Done(); wav_reader.Next()) {
client.SendStartSignal();
std::string utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value();
CHECK_EQ(wave_data.SampFreq(), sample_rate);
int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel);
const int tot_samples = waveform.Dim();
int sample_offset = 0;
while (sample_offset < tot_samples) {
int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset);
std::vector<int16> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk[i] = static_cast<int16>(waveform(sample_offset + i));
}
client.SendBinaryData(wav_chunk.data(),
wav_chunk.size() * sizeof(int16));
sample_offset += cur_chunk_size;
LOG(INFO) << "Send " << cur_chunk_size << " samples";
std::this_thread::sleep_for(
std::chrono::milliseconds(static_cast<int>(1 * 1000)));
if (cur_chunk_size < chunk_sample_size) {
client.SendEndSignal();
}
}
while (!client.Done()) {
}
std::string result = client.GetResult();
LOG(INFO) << "utt: " << utt << " " << result;
client.Join();
return 0;
}
return 0;
}

@ -0,0 +1,30 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "websocket/websocket_server.h"
#include "decoder/param.h"
DEFINE_int32(port, 201314, "websocket listening port");
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure();
ppspeech::WebSocketServer server(FLAGS_port, resource);
LOG(INFO) << "Listening at port " << FLAGS_port;
server.Start();
return 0;
}

@ -30,4 +30,10 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/decoder ${CMAKE_CURRENT_SOURCE_DIR}/decoder
) )
add_subdirectory(decoder) add_subdirectory(decoder)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/websocket
)
add_subdirectory(websocket)

@ -28,8 +28,10 @@
#include <sstream> #include <sstream>
#include <stack> #include <stack>
#include <string> #include <string>
#include <thread>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "base/basic_types.h" #include "base/basic_types.h"

@ -7,5 +7,6 @@ add_library(decoder STATIC
ctc_decoders/path_trie.cpp ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp ctc_decoders/scorer.cpp
ctc_tlg_decoder.cc ctc_tlg_decoder.cc
recognizer.cc
) )
target_link_libraries(decoder PUBLIC kenlm utils fst) target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder)

@ -33,7 +33,6 @@ void TLGDecoder::InitDecoder() {
void TLGDecoder::AdvanceDecode( void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) { const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) { while (!decodable->IsLastFrame(frame_decoded_size_)) {
LOG(INFO) << "num frame decode: " << frame_decoded_size_;
AdvanceDecoding(decodable.get()); AdvanceDecoding(decodable.get());
} }
} }
@ -63,4 +62,4 @@ std::string TLGDecoder::GetFinalBestPath() {
} }
return words; return words;
} }
} }

@ -0,0 +1,94 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size");
DEFINE_bool(convert2PCM32, true, "audio convert to pcm32");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(params_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam");
DEFINE_int32(receptive_field_length,
7,
"receptive field of two CNN(kernel=5) downsampling module.");
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");
namespace ppspeech {
// todo refactor later
FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk;
opts.convert2PCM32 = FLAGS_convert2PCM32;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.frame_length_ms = 20;
frame_opts.frame_shift_ms = 10;
frame_opts.remove_dc_offset = false;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
frame_opts.dither = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length;
opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate;
return opts;
}
ModelOptions InitModelOptions() {
ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.params_path = FLAGS_params_path;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.output_names = FLAGS_model_output_names;
return model_opts;
}
TLGDecoderOptions InitDecoderOptions() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
return decoder_opts;
}
RecognizerResource InitRecognizerResoure() {
RecognizerResource resource;
resource.acoustic_scale = FLAGS_acoustic_scale;
resource.feature_pipeline_opts = InitFeaturePipelineOptions();
resource.model_opts = InitModelOptions();
resource.tlg_opts = InitDecoderOptions();
return resource;
}
}

@ -0,0 +1,60 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/recognizer.h"
namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector;
using std::unique_ptr;
Recognizer::Recognizer(const RecognizerResource& resource) {
// resource_ = resource;
const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
feature_pipeline_.reset(new FeaturePipeline(feature_opts));
std::shared_ptr<PaddleNnet> nnet(new PaddleNnet(resource.model_opts));
BaseFloat ac_scale = resource.acoustic_scale;
decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale));
decoder_.reset(new TLGDecoder(resource.tlg_opts));
input_finished_ = false;
}
void Recognizer::Accept(const Vector<BaseFloat>& waves) {
feature_pipeline_->Accept(waves);
}
void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); }
std::string Recognizer::GetFinalResult() {
return decoder_->GetFinalBestPath();
}
void Recognizer::SetFinished() {
feature_pipeline_->SetFinished();
input_finished_ = true;
}
bool Recognizer::IsFinished() { return input_finished_; }
void Recognizer::Reset() {
feature_pipeline_->Reset();
decodable_->Reset();
decoder_->Reset();
}
} // namespace ppspeech

@ -0,0 +1,59 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
namespace ppspeech {
struct RecognizerResource {
FeaturePipelineOptions feature_pipeline_opts;
ModelOptions model_opts;
TLGDecoderOptions tlg_opts;
// CTCBeamSearchOptions beam_search_opts;
kaldi::BaseFloat acoustic_scale;
RecognizerResource()
: acoustic_scale(1.0),
feature_pipeline_opts(),
model_opts(),
tlg_opts() {}
};
class Recognizer {
public:
explicit Recognizer(const RecognizerResource& resouce);
void Accept(const kaldi::Vector<kaldi::BaseFloat>& waves);
void Decode();
std::string GetFinalResult();
void SetFinished();
bool IsFinished();
void Reset();
private:
// std::shared_ptr<RecognizerResource> resource_;
// RecognizerResource resource_;
std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<Decodable> decodable_;
std::unique_ptr<TLGDecoder> decoder_;
bool input_finished_;
};
} // namespace ppspeech

@ -6,6 +6,7 @@ add_library(frontend STATIC
linear_spectrogram.cc linear_spectrogram.cc
audio_cache.cc audio_cache.cc
feature_cache.cc feature_cache.cc
feature_pipeline.cc
) )
target_link_libraries(frontend PUBLIC kaldi-matrix) target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common)

@ -41,7 +41,7 @@ void AudioCache::Accept(const VectorBase<BaseFloat>& waves) {
ready_feed_condition_.wait(lock); ready_feed_condition_.wait(lock);
} }
for (size_t idx = 0; idx < waves.Dim(); ++idx) { for (size_t idx = 0; idx < waves.Dim(); ++idx) {
int32 buffer_idx = (idx + offset_) % ring_buffer_.size(); int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size();
ring_buffer_[buffer_idx] = waves(idx); ring_buffer_[buffer_idx] = waves(idx);
if (convert2PCM32_) if (convert2PCM32_)
ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx)); ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx));

@ -24,7 +24,7 @@ namespace ppspeech {
class AudioCache : public FrontendInterface { class AudioCache : public FrontendInterface {
public: public:
explicit AudioCache(int buffer_size = 1000 * kint16max, explicit AudioCache(int buffer_size = 1000 * kint16max,
bool convert2PCM32 = false); bool convert2PCM32 = true);
virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves); virtual void Accept(const kaldi::VectorBase<BaseFloat>& waves);

@ -23,10 +23,13 @@ using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using std::unique_ptr; using std::unique_ptr;
FeatureCache::FeatureCache(int max_size, FeatureCache::FeatureCache(FeatureCacheOptions opts,
unique_ptr<FrontendInterface> base_extractor) { unique_ptr<FrontendInterface> base_extractor) {
max_size_ = max_size; max_size_ = opts.max_size;
frame_chunk_stride_ = opts.frame_chunk_stride;
frame_chunk_size_ = opts.frame_chunk_size;
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim();
} }
void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) { void FeatureCache::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
@ -44,13 +47,14 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
while (cache_.empty() && base_extractor_->IsFinished() == false) { while (cache_.empty() && base_extractor_->IsFinished() == false) {
ready_read_condition_.wait(lock); // todo refactor: wait
BaseFloat elapsed = timer.Elapsed() * 1000; // ready_read_condition_.wait(lock);
// todo replace 1.0 with timeout_ int32 elapsed = static_cast<int32>(timer.Elapsed() * 1000);
if (elapsed > 1.0) { // todo replace 1 with timeout_, 1 ms
if (elapsed > 1) {
return false; return false;
} }
usleep(1000); // sleep 1 ms usleep(100); // sleep 0.1 ms
} }
if (cache_.empty()) return false; if (cache_.empty()) return false;
feats->Resize(cache_.front().Dim()); feats->Resize(cache_.front().Dim());
@ -63,25 +67,41 @@ bool FeatureCache::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
// read all data from base_feature_extractor_ into cache_ // read all data from base_feature_extractor_ into cache_
bool FeatureCache::Compute() { bool FeatureCache::Compute() {
// compute and feed // compute and feed
Vector<BaseFloat> feature_chunk; Vector<BaseFloat> feature;
bool result = base_extractor_->Read(&feature_chunk); bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false;
int32 joint_len = feature.Dim() + remained_feature_.Dim();
int32 num_chunk =
((joint_len / dim_) - frame_chunk_size_) / frame_chunk_stride_ + 1;
std::unique_lock<std::mutex> lock(mutex_); Vector<BaseFloat> joint_feature(joint_len);
while (cache_.size() >= max_size_) { joint_feature.Range(0, remained_feature_.Dim())
ready_feed_condition_.wait(lock); .CopyFromVec(remained_feature_);
} joint_feature.Range(remained_feature_.Dim(), feature.Dim())
.CopyFromVec(feature);
// feed cache for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
if (feature_chunk.Dim() != 0) { int32 start = chunk_idx * frame_chunk_stride_ * dim_;
Vector<BaseFloat> feature_chunk(frame_chunk_size_ * dim_);
SubVector<BaseFloat> tmp(joint_feature.Data() + start,
frame_chunk_size_ * dim_);
feature_chunk.CopyFromVec(tmp);
std::unique_lock<std::mutex> lock(mutex_);
while (cache_.size() >= max_size_) {
ready_feed_condition_.wait(lock);
}
// feed cache
cache_.push(feature_chunk); cache_.push(feature_chunk);
ready_read_condition_.notify_one();
} }
ready_read_condition_.notify_one(); int32 remained_feature_len =
joint_len - num_chunk * frame_chunk_stride_ * dim_;
remained_feature_.Resize(remained_feature_len);
remained_feature_.CopyFromVec(joint_feature.Range(
frame_chunk_stride_ * num_chunk * dim_, remained_feature_len));
return result; return result;
} }
void Reset() {
// std::lock_guard<std::mutex> lock(mutex_);
return;
}
} // namespace ppspeech } // namespace ppspeech

@ -19,10 +19,18 @@
namespace ppspeech { namespace ppspeech {
struct FeatureCacheOptions {
int32 max_size;
int32 frame_chunk_size;
int32 frame_chunk_stride;
FeatureCacheOptions()
: max_size(kint16max), frame_chunk_size(1), frame_chunk_stride(1) {}
};
class FeatureCache : public FrontendInterface { class FeatureCache : public FrontendInterface {
public: public:
explicit FeatureCache( explicit FeatureCache(
int32 max_size = kint16max, FeatureCacheOptions opts,
std::unique_ptr<FrontendInterface> base_extractor = NULL); std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves // Feed feats or waves
@ -32,12 +40,15 @@ class FeatureCache : public FrontendInterface {
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats); virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// feat dim // feat dim
virtual size_t Dim() const { return base_extractor_->Dim(); } virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { virtual void SetFinished() {
// std::unique_lock<std::mutex> lock(mutex_);
base_extractor_->SetFinished(); base_extractor_->SetFinished();
LOG(INFO) << "set finished";
// read the last chunk data // read the last chunk data
Compute(); Compute();
// ready_feed_condition_.notify_one();
} }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
@ -52,9 +63,13 @@ class FeatureCache : public FrontendInterface {
private: private:
bool Compute(); bool Compute();
int32 dim_;
size_t max_size_; size_t max_size_;
std::unique_ptr<FrontendInterface> base_extractor_; int32 frame_chunk_size_;
int32 frame_chunk_stride_;
kaldi::Vector<kaldi::BaseFloat> remained_feature_;
std::unique_ptr<FrontendInterface> base_extractor_;
std::mutex mutex_; std::mutex mutex_;
std::queue<kaldi::Vector<BaseFloat>> cache_; std::queue<kaldi::Vector<BaseFloat>> cache_;
std::condition_variable ready_feed_condition_; std::condition_variable ready_feed_condition_;

@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/audio/feature_pipeline.h"
namespace ppspeech {
using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.convert2PCM32));
unique_ptr<FrontendInterface> linear_spectrogram(
new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts,
std::move(data_source)));
unique_ptr<FrontendInterface> cmvn(
new ppspeech::CMVN(opts.cmvn_file, std::move(linear_spectrogram)));
base_extractor_.reset(
new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn)));
}
} // ppspeech

@ -0,0 +1,57 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor later (SGoat)
#pragma once
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
namespace ppspeech {
struct FeaturePipelineOptions {
std::string cmvn_file;
bool convert2PCM32;
LinearSpectrogramOptions linear_spectrogram_opts;
FeatureCacheOptions feature_cache_opts;
FeaturePipelineOptions()
: cmvn_file(""),
convert2PCM32(false),
linear_spectrogram_opts(),
feature_cache_opts() {}
};
class FeaturePipeline : public FrontendInterface {
public:
explicit FeaturePipeline(const FeaturePipelineOptions& opts);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
return base_extractor_->Read(feats);
}
virtual size_t Dim() const { return base_extractor_->Dim(); }
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); }
private:
std::unique_ptr<FrontendInterface> base_extractor_;
};
}

@ -52,16 +52,16 @@ bool LinearSpectrogram::Read(Vector<BaseFloat>* feats) {
if (flag == false || input_feats.Dim() == 0) return false; if (flag == false || input_feats.Dim() == 0) return false;
int32 feat_len = input_feats.Dim(); int32 feat_len = input_feats.Dim();
int32 left_len = reminded_wav_.Dim(); int32 left_len = remained_wav_.Dim();
Vector<BaseFloat> waves(feat_len + left_len); Vector<BaseFloat> waves(feat_len + left_len);
waves.Range(0, left_len).CopyFromVec(reminded_wav_); waves.Range(0, left_len).CopyFromVec(remained_wav_);
waves.Range(left_len, feat_len).CopyFromVec(input_feats); waves.Range(left_len, feat_len).CopyFromVec(input_feats);
Compute(waves, feats); Compute(waves, feats);
int32 frame_shift = opts_.frame_opts.WindowShift(); int32 frame_shift = opts_.frame_opts.WindowShift();
int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts); int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts);
int32 left_samples = waves.Dim() - frame_shift * num_frames; int32 left_samples = waves.Dim() - frame_shift * num_frames;
reminded_wav_.Resize(left_samples); remained_wav_.Resize(left_samples);
reminded_wav_.CopyFromVec( remained_wav_.CopyFromVec(
waves.Range(frame_shift * num_frames, left_samples)); waves.Range(frame_shift * num_frames, left_samples));
return true; return true;
} }

@ -25,12 +25,12 @@ struct LinearSpectrogramOptions {
kaldi::FrameExtractionOptions frame_opts; kaldi::FrameExtractionOptions frame_opts;
kaldi::BaseFloat streaming_chunk; // second kaldi::BaseFloat streaming_chunk; // second
LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {} LinearSpectrogramOptions() : streaming_chunk(0.1), frame_opts() {}
void Register(kaldi::OptionsItf* opts) { void Register(kaldi::OptionsItf* opts) {
opts->Register("streaming-chunk", opts->Register("streaming-chunk",
&streaming_chunk, &streaming_chunk,
"streaming chunk size, default: 0.36 sec"); "streaming chunk size, default: 0.1 sec");
frame_opts.Register(opts); frame_opts.Register(opts);
} }
}; };
@ -48,7 +48,7 @@ class LinearSpectrogram : public FrontendInterface {
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { virtual void Reset() {
base_extractor_->Reset(); base_extractor_->Reset();
reminded_wav_.Resize(0); remained_wav_.Resize(0);
} }
private: private:
@ -60,7 +60,7 @@ class LinearSpectrogram : public FrontendInterface {
kaldi::BaseFloat hanning_window_energy_; kaldi::BaseFloat hanning_window_energy_;
LinearSpectrogramOptions opts_; LinearSpectrogramOptions opts_;
std::unique_ptr<FrontendInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
kaldi::Vector<kaldi::BaseFloat> reminded_wav_; kaldi::Vector<kaldi::BaseFloat> remained_wav_;
int chunk_sample_size_; int chunk_sample_size_;
DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram);
}; };

@ -78,7 +78,6 @@ bool Decodable::AdvanceChunk() {
} }
int32 nnet_dim = 0; int32 nnet_dim = 0;
Vector<BaseFloat> inferences; Vector<BaseFloat> inferences;
Matrix<BaseFloat> nnet_cache_tmp;
nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim);
nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim);
nnet_cache_.CopyRowsFromVec(inferences); nnet_cache_.CopyRowsFromVec(inferences);

Loading…
Cancel
Save