Merge pull request #2015 from zh794390558/endpoint

[server][asr] support endpoint for conformer streaming model
pull/2016/head
YangZhou 3 years ago committed by GitHub
commit 8641608f08
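At a high level, this change adds CTC-based endpoint detection to the streaming conformer engine and, when continuous_decoding is enabled, keeps the connection decoding across utterances. Below is a minimal, hedged sketch of how a serving loop might drive the new handler API; it only uses names introduced in this diff (new_handler, extract_feat, decode, rescoring, endpoint_state, continuous_decoding, reset_continuous_decoding, get_result, get_word_time_stamp), but the loop itself is illustrative, not the real server code.

# Hedged sketch: drive the endpoint-aware connection handler (illustrative, not the real server loop).
def serve_connection(asr_engine, audio_packages):
    handler = asr_engine.new_handler()           # PaddleASRConnectionHanddler
    for pcm_chunk in audio_packages:             # raw int16 PCM bytes per package
        handler.extract_feat(pcm_chunk)          # buffer samples, compute fbank, cache features
        handler.decode(is_finished=False)        # advance CTC prefix beam search on cached features
        if handler.endpoint_state:               # OnlineCTCEndpoint fired on this chunk
            handler.rescoring()                  # attention rescoring of the current hypotheses
            result = handler.get_result()
            times = handler.get_word_time_stamp()
            if handler.continuous_decoding:
                handler.reset_continuous_decoding()   # keep the connection, start a new utterance
            else:
                return result, times             # finish, mirroring the websocket 'finished' reply
    handler.decode(is_finished=True)             # flush remaining frames at stream end
    handler.rescoring()
    return handler.get_result(), handler.get_word_time_stamp()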

@ -51,12 +51,12 @@ repos:
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- id: copyright_checker
name: copyright_checker
entry: python .pre-commit-hooks/copyright-check.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
#- id: copyright_checker
# name: copyright_checker
# entry: python .pre-commit-hooks/copyright-check.hook
# language: system
# files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
# exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0
hooks:

@ -31,6 +31,8 @@ asr_online:
force_yes: True
device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
continuous_decoding: True # enable continuous decoding when an endpoint is detected
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True

@ -30,6 +30,9 @@ asr_online:
decode_method:
force_yes: True
device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
continuous_decoding: True # enable continuous decoding when an endpoint is detected
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True

@ -31,6 +31,8 @@ asr_online:
force_yes: True
device: 'cpu' # cpu or gpu:id
decode_method: "attention_rescoring"
continuous_decoding: True # enable continuous decoding when an endpoint is detected
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True

@ -5,4 +5,5 @@ export CUDA_VISIBLE_DEVICE=0,1,2,3
paddlespeech_server start --config_file conf/punc_application.yaml &> punc.log &
# nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log &
paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log &

@ -9,4 +9,5 @@ paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wa
# read the wav and call streaming and punc service
# If `127.0.0.1` is not accessible, you need to use the actual service IP address.
# python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav

@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os

@ -145,4 +145,3 @@ for com, info in _commands.items():
name='paddlespeech.{}'.format(com),
description=info[0],
cls='paddlespeech.cli.{}.{}'.format(com, info[1]))

@ -21,12 +21,12 @@ from typing import Union
import numpy as np
import paddle
import yaml
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
from paddleaudio import load
from paddleaudio.features import LogMelSpectrogram
__all__ = ['CLSExecutor']

@ -22,13 +22,13 @@ from typing import Union
import paddle
import soundfile
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from yacs.config import CfgNode
from ..executor import BaseExecutor
from ..log import logger
from ..utils import stats_wrapper
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram
from paddlespeech.vector.io.batch import feature_normalize
from paddlespeech.vector.modules.sid_model import SpeakerIdetification

@ -22,8 +22,7 @@ model_alias = {
# -------------- ASR --------------
# ---------------------------------
"deepspeech2offline": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"],
"deepspeech2online":
["paddlespeech.s2t.models.ds2:DeepSpeech2Model"],
"deepspeech2online": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"],
"conformer": ["paddlespeech.s2t.models.u2:U2Model"],
"conformer_online": ["paddlespeech.s2t.models.u2:U2Model"],
"transformer": ["paddlespeech.s2t.models.u2:U2Model"],

@ -76,7 +76,8 @@ class CTCPrefixScorePD():
last_ids = [yi[-1] for yi in y] # last output label ids
n_bh = len(last_ids) # batch * hyps
n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps
self.scoring_num = paddle.shape(scoring_ids)[-1] if scoring_ids is not None else 0
self.scoring_num = paddle.shape(scoring_ids)[
-1] if scoring_ids is not None else 0
# prepare state info
if state is None:
r_prev = paddle.full(

@ -22,11 +22,9 @@ import numpy as np
import paddle
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model
from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog
@ -238,8 +236,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
def __init__(self, config, args):
super().__init__(config, args)
self._text_featurizer = TextFeaturizer(
unit_type=config.unit_type,
vocab=config.vocab_filepath)
unit_type=config.unit_type, vocab=config.vocab_filepath)
self.vocab_list = self._text_featurizer.vocab_list
def ordid2token(self, texts, texts_len):
@ -248,7 +245,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(self._text_featurizer.defeaturize(ids.numpy().tolist()))
trans.append(
self._text_featurizer.defeaturize(ids.numpy().tolist()))
return trans
def compute_metrics(self,

@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from .deepspeech2 import DeepSpeech2InferModel
from .deepspeech2 import DeepSpeech2Model
from paddlespeech.s2t.utils import dynamic_pip_install
import sys
try:
import paddlespeech_ctcdecoders

@ -372,11 +372,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box=None,
def forward(self,
audio_chunk,
audio_chunk_lens,
chunk_state_h_box=None,
chunk_state_c_box=None):
if self.encoder.rnn_direction == "forward":
eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder(
audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box)
audio_chunk, audio_chunk_lens, chunk_state_h_box,
chunk_state_c_box)
probs_chunk = self.decoder.softmax(eouts_chunk)
return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box
elif self.encoder.rnn_direction == "bidirect":
@ -392,8 +396,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model):
self,
input_spec=[
paddle.static.InputSpec(
shape=[None, None,
self.encoder.feat_size], #[B, chunk_size, feat_dim]
shape=[None, None, self.encoder.feat_size
], #[B, chunk_size, feat_dim]
dtype='float32'),
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]

@ -90,7 +90,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def _target_mask(self, ys_in_pad):
ys_mask = ys_in_pad != 0
m = subsequent_mask(paddle.shape(ys_mask)[-1])).unsqueeze(0)
m = subsequent_mask(paddle.shape(ys_mask)[-1]).unsqueeze(0)
return ys_mask.unsqueeze(-2) & m
def forward(self, x: paddle.Tensor, t: paddle.Tensor

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import paddle

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Union
import paddle
@ -22,7 +23,6 @@ from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.modules.loss import CTCLoss
from paddlespeech.s2t.utils import ctc_utils
from paddlespeech.s2t.utils.log import Log
import sys
logger = Log(__name__).getlog()

@ -82,7 +82,8 @@ def pad_sequence(sequences: List[paddle.Tensor],
max_size = paddle.shape(sequences[0])
# (TODO Hui Zhang): slice not support `end==start`
# trailing_dims = max_size[1:]
trailing_dims = tuple(max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
trailing_dims = tuple(
max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else ()
max_len = max([s.shape[0] for s in sequences])
if batch_first:
out_dims = (len(sequences), max_len) + trailing_dims

@ -29,6 +29,7 @@ asr_online:
cfg_path:
decode_method:
force_yes: True
device: # cpu or gpu:id
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'

@ -30,6 +30,8 @@ asr_online:
decode_method:
force_yes: True
device: # cpu or gpu:id
continuous_decoding: True # enable continuous decoding when an endpoint is detected
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True

@ -13,6 +13,7 @@
# limitations under the License.
import os
import sys
from typing import ByteString
from typing import Optional
import numpy as np
@ -30,9 +31,10 @@ from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoingOpt
from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoint
from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor
__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
@ -54,24 +56,35 @@ class PaddleASRConnectionHanddler:
self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine
self.init()
self.reset()
def init(self):
# model_type, sample_rate and text_feature are shared by deepspeech2 and conformer
self.model_type = self.asr_engine.executor.model_type
self.sample_rate = self.asr_engine.executor.sample_rate
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
# extract feat; currently only fbank is used for the conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, (
self.sample_rate, self.preprocess_conf.process[0]['fs'])
self.frame_shift_in_ms = int(
self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
self.continuous_decoding = self.config.get("continuous_decoding", False)
self.init_decoder()
self.reset()
def init_decoder(self):
if "deepspeech2" in self.model_type:
assert self.continuous_decoding is False, "ds2 model does not support endpoint"
self.am_predictor = self.asr_engine.executor.am_predictor
# extract feat, new only fbank in conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
self.decoder = CTCDecoder(
odim=self.model_config.output_dim, # <blank> is in vocab
enc_n_units=self.model_config.rnn_layer_size * 2,
@ -90,142 +103,65 @@ class PaddleASRConnectionHanddler:
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
elif "conformer" in self.model_type or "transformer" in self.model_type:
# acoustic model
self.model = self.asr_engine.executor.model
self.continuous_decoding = self.config.continuous_decoding
logger.info(f"continue decoding: {self.continuous_decoding}")
# ctc decoding config
self.ctc_decode_config = self.asr_engine.executor.config.decode
self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
# extract feat, new only fbank in conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window and frame shift, in samples unit
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
# ctc endpoint
self.endpoint_opt = OnlineCTCEndpoingOpt(
frame_shift_in_ms=self.frame_shift_in_ms, blank=0)
self.endpointer = OnlineCTCEndpoint(self.endpoint_opt)
else:
raise ValueError(f"Not supported: {self.model_type}")
def extract_feat(self, samples):
# we compute the elapsed time of first char occuring
# and we record the start time at the first pcm sample arraving
if "deepspeech2online" in self.model_type:
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The connection remain the audio samples: {self.remained_wav.shape}"
)
# fbank
feat = self.preprocessing(self.remained_wav,
**self.preprocess_args)
feat = paddle.to_tensor(
feat, dtype="float32").unsqueeze(axis=0)
if self.cached_feat is None:
self.cached_feat = feat
else:
assert (len(feat.shape) == 3)
assert (len(self.cached_feat.shape) == 3)
self.cached_feat = paddle.concat(
[self.cached_feat, feat], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
# cur frame step
num_frames = feat.shape[1]
self.num_frames += num_frames
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
)
elif "conformer_online" in self.model_type:
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
self.num_samples += samples.shape[0]
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
)
if len(self.remained_wav) < self.win_length:
# samples not enough for feature window
return 0
def model_reset(self):
if "deepspeech2" in self.model_type:
return
# fbank
x_chunk = self.preprocessing(self.remained_wav,
**self.preprocess_args)
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
# feature cache
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
## conformer
# cache for conformer online
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.encoder_out = None
# conformer decoding state
self.offset = 0 # global offset in decoding frame unit
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
## just for record info
self.chunk_num = 0 # global decoding chunk num, not used
# cur frame step
num_frames = x_chunk.shape[1]
def output_reset(self):
## outputs
# partial/ending decoding results
self.result_transcripts = ['']
# token timestamp result
self.word_time_stamp = []
# global frame step
self.num_frames += num_frames
## just for record
self.hyps = []
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
# one best timestamp viterbi prob is large.
self.time_stamp = []
logger.info(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
)
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}")
else:
raise ValueError(f"not supported: {self.model_type}")
def reset_continuous_decoding(self):
"""
when in continuous decoding, reset for the next utterance.
"""
self.global_frame_offset = self.num_frames
self.model_reset()
self.searcher.reset()
self.endpointer.reset()
self.output_reset()
def reset(self):
if "deepspeech2" in self.model_type:
@ -241,38 +177,87 @@ class PaddleASRConnectionHanddler:
dtype=float32)
self.decoder.reset_decoder(batch_size=1)
if "conformer" in self.model_type or "transformer" in self.model_type:
self.searcher.reset()
self.endpointer.reset()
self.device = None
## common
# global sample and frame step
self.num_samples = 0
self.global_frame_offset = 0
# frame step of cur utterance
self.num_frames = 0
# cache for audio and feat
self.remained_wav = None
self.cached_feat = None
# partial/ending decoding results
self.result_transcripts = ['']
## endpoint
self.endpoint_state = False # True for detect endpoint
## conformer
self.model_reset()
# cache for conformer online
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.encoder_out = None
# conformer decoding state
self.chunk_num = 0 # globa decoding chunk num
self.offset = 0 # global offset in decoding frame unit
self.hyps = []
## outputs
self.output_reset()
# token timestamp result
self.word_time_stamp = []
def extract_feat(self, samples: ByteString):
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
# one best timestamp viterbi prob is large.
self.time_stamp = []
self.num_samples += samples.shape[0]
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.remained_wav stores all the samples,
# including the original remained_wav and this package's samples
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1 # (T,)
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
)
if len(self.remained_wav) < self.win_length:
# samples not enough for feature window
return 0
# fbank
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)
# feature cache
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3) # (B,T,D)
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
# cur frame step
num_frames = x_chunk.shape[1]
# global frame step
self.num_frames += num_frames
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
)
logger.info(f"global samples: {self.num_samples}")
logger.info(f"global frames: {self.num_frames}")
def decode(self, is_finished=False):
"""advance decoding
@ -280,14 +265,12 @@ class PaddleASRConnectionHanddler:
Args:
is_finished (bool, optional): Is last frame or not. Defaults to False.
Raises:
Exception: when not support model.
Returns:
None: nothing
None:
"""
if "deepspeech2online" in self.model_type:
if "deepspeech2" in self.model_type:
decoding_chunk_size = 1 # decoding chunk size = 1, in decoding frame unit
context = 7 # context=7, in audio frame unit
subsampling = 4 # subsampling=4, in audio frame unit
@ -332,9 +315,11 @@ class PaddleASRConnectionHanddler:
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
# extract the audio
x_chunk = self.cached_feat[:, cur:end, :].numpy()
x_chunk_lens = np.array([x_chunk.shape[1]])
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
self.result_transcripts = [trans_best]
@ -409,31 +394,41 @@ class PaddleASRConnectionHanddler:
@paddle.no_grad()
def advance_decoding(self, is_finished=False):
if "deepspeech" in self.model_type:
return
# reset endpoint state
self.endpoint_state = False
logger.info(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
cfg = self.ctc_decode_config
# cur chunk size, in decoding frame unit
# cur chunk size, in decoding frame unit, e.g. 16
decoding_chunk_size = cfg.decoding_chunk_size
# using num of history chunks
# using num of history chunks, e.g -1
num_decoding_left_chunks = cfg.num_decoding_left_chunks
assert decoding_chunk_size > 0
# e.g. 4
subsampling = self.model.encoder.embed.subsampling_rate
# e.g. 7
context = self.model.encoder.embed.right_context + 1
# processed chunk feature cached for next chunk
# processed chunk feature cached for next chunk, e.g. 3
cached_feature_num = context - subsampling
# decoding stride, in audio frame unit
stride = subsampling * decoding_chunk_size
# decoding window, in audio frame unit
decoding_window = (decoding_chunk_size - 1) * subsampling + context
# decoding stride, in audio frame unit
stride = subsampling * decoding_chunk_size
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
# (B=1,T,D)
num_frames = self.cached_feat.shape[1]
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
@ -454,9 +449,6 @@ class PaddleASRConnectionHanddler:
return None, None
logger.info("start to do model forward")
# hist of chunks, in deocding frame unit
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
outputs = []
# num_frames - context + 1 ensure that current frame can get context window
if is_finished:
@ -466,7 +458,11 @@ class PaddleASRConnectionHanddler:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
# hist of chunks, in decoding frame unit
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
# record the end for removing the processed feat
outputs = []
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
@ -491,30 +487,40 @@ class PaddleASRConnectionHanddler:
self.encoder_out = ys
else:
self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
logger.info(
f"This connection handler encoder out shape: {self.encoder_out.shape}"
)
# get the ctc probs
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
## decoding
# advance decoding
self.searcher.search(ctc_probs, self.cached_feat.place)
# get one best hyps
self.hyps = self.searcher.get_one_best_hyps()
assert self.cached_feat.shape[0] == 1
assert end >= cached_feature_num
# endpoint
if not is_finished:
def contain_nonsilence():
return len(self.hyps) > 0 and len(self.hyps[0]) > 0
decoding_something = contain_nonsilence()
if self.endpointer.endpoint_detected(ctc_probs.numpy(),
decoding_something):
self.endpoint_state = True
logger.info(f"Endpoint is detected at {self.num_frames} frame.")
# advance cache of feat
self.cached_feat = self.cached_feat[0, end -
cached_feature_num:, :].unsqueeze(0)
assert self.cached_feat.shape[0] == 1 #(B=1,T,D)
assert end >= cached_feature_num
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
assert len(
self.cached_feat.shape
) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
logger.info(
f"This connection handler encoder out shape: {self.encoder_out.shape}"
)
def update_result(self):
"""Conformer/Transformer hyps to result.
"""
@ -654,24 +660,28 @@ class PaddleASRConnectionHanddler:
# update each word start and end time stamp
# decoding frame to audio frame
frame_shift = self.model.encoder.embed.subsampling_rate
frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate)
logger.info(f"frame shift sec: {frame_shift_in_sec}")
decode_frame_shift = self.model.encoder.embed.subsampling_rate
decode_frame_shift_in_sec = decode_frame_shift * (self.n_shift /
self.sample_rate)
logger.info(f"decode frame shift in sec: {decode_frame_shift_in_sec}")
global_offset_in_sec = self.global_frame_offset * self.frame_shift_in_ms / 1000.0
logger.info(f"global offset: {global_offset_in_sec} sec.")
word_time_stamp = []
for idx, _ in enumerate(self.time_stamp):
start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
) / 2.0 if idx > 0 else 0
start = start * frame_shift_in_sec
start = start * decode_frame_shift_in_sec
end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
end = end * frame_shift_in_sec
end = end * decode_frame_shift_in_sec
word_time_stamp.append({
"w": self.result_transcripts[0][idx],
"bg": start,
"ed": end
"bg": global_offset_in_sec + start,
"ed": global_offset_in_sec + end
})
# logger.info(f"{word_time_stamp[-1]}")
@ -705,13 +715,14 @@ class ASRServerExecutor(ASRExecutor):
self.model_type = model_type
self.sample_rate = sample_rate
logger.info(f"model_type: {self.model_type}")
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(model_tag=tag)
if cfg_path is None or am_model is None or am_params is None:
logger.info(f"Load the pretrained model, tag = {tag}")
self.res_path = self.task_resource.res_dir
self.cfg_path = os.path.join(
self.res_path, self.task_resource.res_dict['cfg_path'])
@ -719,7 +730,6 @@ class ASRServerExecutor(ASRExecutor):
self.task_resource.res_dict['model'])
self.am_params = os.path.join(self.res_path,
self.task_resource.res_dict['params'])
logger.info(self.res_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model)
@ -727,9 +737,12 @@ class ASRServerExecutor(ASRExecutor):
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
logger.info("Load the pretrained model:")
logger.info(f" tag = {tag}")
logger.info(f" res_path: {self.res_path}")
logger.info(f" cfg path: {self.cfg_path}")
logger.info(f" am_model path: {self.am_model}")
logger.info(f" am_params path: {self.am_params}")
#Init body.
self.config = CfgNode(new_allowed=True)
@ -738,25 +751,39 @@ class ASRServerExecutor(ASRExecutor):
if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix)
logger.info(f"spm model path: {self.config.spm_model_prefix}")
self.vocab = self.config.vocab_filepath
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
self.vocab = self.config.vocab_filepath
with UpdateConfig(self.config):
if "deepspeech2" in model_type:
if "deepspeech2" in model_type:
with UpdateConfig(self.config):
# download lm
self.config.decode.lang_model_path = os.path.join(
MODEL_HOME, 'language_model',
self.config.decode.lang_model_path)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
lm_url = self.task_resource.res_dict['lm_url']
lm_md5 = self.task_resource.res_dict['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm(
lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type:
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in model_type or "transformer" in model_type:
with UpdateConfig(self.config):
logger.info("start to create the stream conformer asr engine")
# update the decoding method
if decode_method:
@ -770,37 +797,24 @@ class ASRServerExecutor(ASRExecutor):
logger.info(
"we set the decoding_method to attention_rescoring")
self.config.decode.decoding_method = "attention_rescoring"
assert self.config.decode.decoding_method in [
"ctc_prefix_beam_search", "attention_rescoring"
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
else:
raise Exception("wrong type")
if "deepspeech2" in model_type:
# AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor(
model_file=self.am_model,
params_file=self.am_params,
predictor_conf=self.am_predictor_conf)
elif "conformer" in model_type or "transformer" in model_type:
# load model
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
logger.info(f"model name: {model_name}")
model_class = self.task_resource.get_model_class(model_name)
model_conf = self.config
model = model_class.from_config(model_conf)
model = model_class.from_config(self.config)
self.model = model
self.model.set_state_dict(paddle.load(self.am_model))
self.model.eval()
# load model
model_dict = paddle.load(self.am_model)
self.model.set_state_dict(model_dict)
logger.info("create the transformer like model success")
else:
raise ValueError(f"Not support: {model_type}")
raise Exception(f"not support: {model_type}")
logger.info(f"create the {model_type} model success")
return True
@ -857,6 +871,14 @@ class ASREngine(BaseEngine):
logger.info("Initialize ASR server engine successfully.")
return True
def new_handler(self):
"""New handler from model.
Returns:
PaddleASRConnectionHanddler: asr handler instance
"""
return PaddleASRConnectionHanddler(self)
def preprocess(self, *args, **kwargs):
raise NotImplementedError("Online not using this.")

@ -0,0 +1,118 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import numpy as np
from paddlespeech.cli.log import logger
@dataclass
class OnlineCTCEndpointRule:
must_contain_nonsilence: bool = True
min_trailing_silence: int = 1000
min_utterance_length: int = 0
@dataclass
class OnlineCTCEndpoingOpt:
frame_shift_in_ms: int = 10
blank: int = 0 # blank id, that we consider as silence for purposes of endpointing.
blank_threshold: float = 0.8 # above blank threshold is silence
# We support three rules. We terminate decoding if ANY of these rules
# evaluates to "true". If you want to add more rules, do it by changing this
# code. If you want to disable a rule, you can set the silence-timeout for
# that rule to a very large number.
# rule1 times out after 5 seconds of silence, even if we decoded nothing.
rule1: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 5000, 0)
# rule2 times out after 1.0 second of silence after decoding something,
# even if we did not reach a final-state at all.
rule2: OnlineCTCEndpointRule = OnlineCTCEndpointRule(True, 1000, 0)
# rule3 times out after the utterance is 20 seconds long, regardless of
# anything else.
rule3: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 0, 20000)
class OnlineCTCEndpoint:
"""
[END-TO-END AUTOMATIC SPEECH RECOGNITION INTEGRATED WITH CTC-BASED VOICE ACTIVITY DETECTION](https://arxiv.org/pdf/2002.00551.pdf)
"""
def __init__(self, opts: OnlineCTCEndpoingOpt):
self.opts = opts
logger.info(f"Endpont Opts: {opts}")
self.frame_shift_in_ms = opts.frame_shift_in_ms
self.num_frames_decoded = 0
self.trailing_silence_frames = 0
self.reset()
def reset(self):
self.num_frames_decoded = 0
self.trailing_silence_frames = 0
def rule_activated(self,
rule: OnlineCTCEndpointRule,
rule_name: str,
decoding_something: bool,
trailine_silence: int,
utterance_length: int) -> bool:
ans = (
decoding_something or (not rule.must_contain_nonsilence)
) and trailine_silence >= rule.min_trailing_silence and utterance_length >= rule.min_utterance_length
if (ans):
logger.info(f"Endpoint Rule: {rule_name} activated: {rule}")
return ans
def endpoint_detected(self,
ctc_log_probs: np.ndarray,
decoding_something: bool) -> bool:
"""detect endpoint.
Args:
ctc_log_probs (np.ndarray): (T, D)
decoding_something (bool): whether decoding contains non-silence.
Returns:
bool: whether endpoint detected.
"""
for logprob in ctc_log_probs:
blank_prob = np.exp(logprob[self.opts.blank])
self.num_frames_decoded += 1
if blank_prob > self.opts.blank_threshold:
self.trailing_silence_frames += 1
else:
self.trailing_silence_frames = 0
assert self.num_frames_decoded >= self.trailing_silence_frames
assert self.frame_shift_in_ms > 0
utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms
if self.rule_activated(self.opts.rule1, 'rule1', decoding_something,
trailing_silence, utterance_length):
return True
if self.rule_activated(self.opts.rule2, 'rule2', decoding_something,
trailing_silence, utterance_length):
return True
if self.rule_activated(self.opts.rule3, 'rule3', decoding_something,
trailing_silence, utterance_length):
return True
return False
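A hedged usage sketch of the endpointer defined above; the log-probabilities are synthetic and only exercise the blank-probability path, so rule2 (at least 1 s of trailing silence after decoding something) is the rule that fires:

# Hedged sketch: feed synthetic CTC log-probs to the endpointer (toy data, not real model output).
import numpy as np

opts = OnlineCTCEndpoingOpt(frame_shift_in_ms=10, blank=0)
endpointer = OnlineCTCEndpoint(opts)

# 120 frames (~1.2 s) dominated by blank, i.e. trailing silence after some decoded tokens.
chunk = np.full((120, 10), np.log(1e-6), dtype=np.float32)
chunk[:, 0] = np.log(0.99)          # blank prob 0.99 > blank_threshold (0.8)

fired = endpointer.endpoint_detected(chunk, decoding_something=True)
print(fired)                        # True: rule2 activates once trailing silence >= 1000 ms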

@ -30,8 +30,29 @@ class CTCPrefixBeamSearch:
config (yacs.config.CfgNode): the ctc prefix beam search configuration
"""
self.config = config
# beam size
self.first_beam_size = self.config.beam_size
# TODO(support second beam size)
self.second_beam_size = int(self.first_beam_size * 1.0)
logger.info(
f"first and second beam size: {self.first_beam_size}, {self.second_beam_size}"
)
# state
self.cur_hyps = None
self.hyps = None
self.abs_time_step = 0
self.reset()
def reset(self):
"""Rest the search cache value
"""
self.cur_hyps = None
self.hyps = None
self.abs_time_step = 0
@paddle.no_grad()
def search(self, ctc_probs, device, blank_id=0):
"""ctc prefix beam search method decode a chunk feature
@ -47,12 +68,17 @@ class CTCPrefixBeamSearch:
"""
# decode
logger.info("start to ctc prefix search")
assert len(ctc_probs.shape) == 2
batch_size = 1
beam_size = self.config.beam_size
maxlen = ctc_probs.shape[0]
assert len(ctc_probs.shape) == 2
vocab_size = ctc_probs.shape[1]
first_beam_size = min(self.first_beam_size, vocab_size)
second_beam_size = min(self.second_beam_size, vocab_size)
logger.info(
f"effect first and second beam size: {self.first_beam_size}, {self.second_beam_size}"
)
maxlen = ctc_probs.shape[0]
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# 0. blank_ending_score,
@ -75,7 +101,8 @@ class CTCPrefixBeamSearch:
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
top_k_logp, top_k_index = logp.topk(
first_beam_size) # (first_beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
@ -148,7 +175,7 @@ class CTCPrefixBeamSearch:
next_hyps.items(),
key=lambda x: log_add([x[1][0], x[1][1]]),
reverse=True)
self.cur_hyps = next_hyps[:beam_size]
self.cur_hyps = next_hyps[:second_beam_size]
# 2.3 update the absolute time step
self.abs_time_step += 1
@ -163,7 +190,7 @@ class CTCPrefixBeamSearch:
"""Return the one best result
Returns:
list: the one best result
list: the one best result, List[str]
"""
return [self.hyps[0][0]]
@ -171,17 +198,10 @@ class CTCPrefixBeamSearch:
"""Return the search hyps
Returns:
list: return the search hyps
list: return the search hyps, List[Tuple[str, float, ...]]
"""
return self.hyps
def reset(self):
"""Rest the search cache value
"""
self.cur_hyps = None
self.hyps = None
self.abs_time_step = 0
def finalize_search(self):
"""do nothing in ctc_prefix_beam_search
"""

@ -42,7 +42,6 @@ class TTSServerExecutor(TTSExecutor):
self.task_resource = CommonTaskResource(
task='tts', model_format='dynamic', inference_mode='online')
def get_model_info(self,
field: str,
model_name: str,

@ -19,7 +19,6 @@ from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.engine.engine_pool import get_engine_pool
router = APIRouter()
@ -38,7 +37,7 @@ async def websocket_endpoint(websocket: WebSocket):
#2. if we accept the websocket headers, we will get the online asr engine instance
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
asr_model = engine_pool['asr']
#3. for each websocket connection, we will create a PaddleASRConnectionHanddler to process the audio
# and each connection has its own connection instance to process the request
@ -70,7 +69,8 @@ async def websocket_endpoint(websocket: WebSocket):
resp = {"status": "ok", "signal": "server_ready"}
# do something at the beginning here
# create the instance to process the audio
connection_handler = PaddleASRConnectionHanddler(asr_engine)
#connection_handler = PaddleASRConnectionHanddler(asr_model)
connection_handler = asr_model.new_handler()
await websocket.send_json(resp)
elif message['signal'] == 'end':
# reset single engine for an new connection
@ -100,11 +100,34 @@ async def websocket_endpoint(websocket: WebSocket):
# and decode for the result in this package data
connection_handler.extract_feat(message)
connection_handler.decode(is_finished=False)
if connection_handler.endpoint_state:
logger.info("endpoint: detected and rescoring.")
connection_handler.rescoring()
word_time_stamp = connection_handler.get_word_time_stamp()
asr_results = connection_handler.get_result()
# return the current period result
# if the engine create the vad instance, this connection will have many period results
if connection_handler.endpoint_state:
if connection_handler.continuous_decoding:
logger.info("endpoint: continue decoding")
connection_handler.reset_continuous_decoding()
else:
logger.info("endpoint: exit decoding")
# ending by endpoint
resp = {
"status": "ok",
"signal": "finished",
'result': asr_results,
'times': word_time_stamp
}
await websocket.send_json(resp)
break
# return the current partial result
# if the engine creates the vad instance, this connection will have many partial results
resp = {'result': asr_results}
await websocket.send_json(resp)
except WebSocketDisconnect as e:
logger.error(e)
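On the client side, a hedged sketch of how the new replies could be consumed; the handshake and audio upload are omitted, and `ws` is assumed to be an already-connected websocket (e.g. from the `websockets` package) pointing at this ASR route:

# Hedged sketch: consume partial results and the new endpoint 'finished' reply.
import json

async def read_results(ws):
    async for raw in ws:
        msg = json.loads(raw)
        if msg.get("signal") == "finished":
            # endpoint detected with continuous_decoding disabled: final text plus word times
            print("final:", msg["result"], msg.get("times"))
            break
        if "result" in msg:
            # ordinary partial result for the current utterance
            print("partial:", msg["result"])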

@ -140,10 +140,7 @@ def parse_args():
],
help='Choose acoustic model type of tts task.')
parser.add_argument(
'--am_config',
type=str,
default=None,
help='Config of acoustic model.')
'--am_config', type=str, default=None, help='Config of acoustic model.')
parser.add_argument(
'--am_ckpt',
type=str,
@ -179,10 +176,7 @@ def parse_args():
],
help='Choose vocoder type of tts task.')
parser.add_argument(
'--voc_config',
type=str,
default=None,
help='Config of voc.')
'--voc_config', type=str, default=None, help='Config of voc.')
parser.add_argument(
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument(

@ -174,10 +174,7 @@ def parse_args():
],
help='Choose acoustic model type of tts task.')
parser.add_argument(
'--am_config',
type=str,
default=None,
help='Config of acoustic model.')
'--am_config', type=str, default=None, help='Config of acoustic model.')
parser.add_argument(
'--am_ckpt',
type=str,
@ -220,10 +217,7 @@ def parse_args():
],
help='Choose vocoder type of tts task.')
parser.add_argument(
'--voc_config',
type=str,
default=None,
help='Config of voc.')
'--voc_config', type=str, default=None, help='Config of voc.')
parser.add_argument(
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument(

@ -131,10 +131,7 @@ def parse_args():
choices=['fastspeech2_aishell3', 'tacotron2_aishell3'],
help='Choose acoustic model type of tts task.')
parser.add_argument(
'--am_config',
type=str,
default=None,
help='Config of acoustic model.')
'--am_config', type=str, default=None, help='Config of acoustic model.')
parser.add_argument(
'--am_ckpt',
type=str,
@ -160,10 +157,7 @@ def parse_args():
help='Choose vocoder type of tts task.')
parser.add_argument(
'--voc_config',
type=str,
default=None,
help='Config of voc.')
'--voc_config', type=str, default=None, help='Config of voc.')
parser.add_argument(
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument(

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .vits import *
from .vits_updater import *
from .vits_updater import *

@ -56,7 +56,8 @@ class VITSUpdater(StandardUpdater):
self.models: Dict[str, Layer] = models
# self.model = model
self.model = model._layers if isinstance(model, paddle.DataParallel) else model
self.model = model._layers if isinstance(model,
paddle.DataParallel) else model
self.optimizers = optimizers
self.optimizer_g: Optimizer = optimizers['generator']
@ -225,7 +226,8 @@ class VITSEvaluator(StandardEvaluator):
models = {"main": model}
self.models: Dict[str, Layer] = models
# self.model = model
self.model = model._layers if isinstance(model, paddle.DataParallel) else model
self.model = model._layers if isinstance(model,
paddle.DataParallel) else model
self.criterions = criterions
self.criterion_mel = criterions['mel']

@ -971,18 +971,18 @@ class FeatureMatchLoss(nn.Layer):
return feat_match_loss
# loss for VITS
class KLDivergenceLoss(nn.Layer):
"""KL divergence loss."""
def forward(
self,
z_p: paddle.Tensor,
logs_q: paddle.Tensor,
m_p: paddle.Tensor,
logs_p: paddle.Tensor,
z_mask: paddle.Tensor,
) -> paddle.Tensor:
self,
z_p: paddle.Tensor,
logs_q: paddle.Tensor,
m_p: paddle.Tensor,
logs_p: paddle.Tensor,
z_mask: paddle.Tensor, ) -> paddle.Tensor:
"""Calculate KL divergence loss.
Args:
@ -1002,8 +1002,8 @@ class KLDivergenceLoss(nn.Layer):
logs_p = paddle.cast(logs_p, 'float32')
z_mask = paddle.cast(z_mask, 'float32')
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * paddle.exp(-2.0 * logs_p)
kl += 0.5 * ((z_p - m_p)**2) * paddle.exp(-2.0 * logs_p)
kl = paddle.sum(kl * z_mask)
loss = kl / paddle.sum(z_mask)
return loss
return loss

@ -25,4 +25,3 @@ netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host
> Reminder: Only for developers; make sure you know what it is.
* codelab - for speechx developer, using for test.

@ -3,4 +3,4 @@
## Examples
* `websocket` - Streaming ASR with websocket for deepspeech2_aishell.
* `aishell` - Streaming Decoding under aishell dataset, for local WER test.
* `aishell` - Streaming Decoding under aishell dataset, for local WER test.

@ -4,4 +4,3 @@
> Reminder: Only for developer.
* codelab - for speechx developer, using for test.

@ -91,8 +91,8 @@ int main(int argc, char* argv[]) {
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
int32 chunk_size = FLAGS_receptive_field_length
+ (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;

@ -64,7 +64,7 @@ std::string TLGDecoder::GetPartialResult() {
std::string word = word_symbol_table_->Find(words_id[idx]);
words += word;
}
return words;
return words;
}
std::string TLGDecoder::GetFinalBestPath() {

@ -82,7 +82,7 @@ FeaturePipelineOptions InitFeaturePipelineOptions() {
opts.assembler_opts.subsampling_rate = FLAGS_downsampling_rate;
opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length;
opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk;
return opts;
}

@ -93,8 +93,8 @@ int main(int argc, char* argv[]) {
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length
+ (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;

@ -24,7 +24,8 @@ using std::unique_ptr;
Assembler::Assembler(AssemblerOptions opts,
unique_ptr<FrontendInterface> base_extractor) {
frame_chunk_stride_ = opts.subsampling_rate * opts.nnet_decoder_chunk;
frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate + opts.receptive_filed_length;
frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate +
opts.receptive_filed_length;
receptive_filed_length_ = opts.receptive_filed_length;
base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim();
@ -50,8 +51,8 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
Vector<BaseFloat> feature;
result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) {
if (IsFinished() == false) return false;
break;
if (IsFinished() == false) return false;
break;
}
feature_cache_.push(feature);
}
@ -61,22 +62,22 @@ bool Assembler::Compute(Vector<BaseFloat>* feats) {
}
while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature(dim_, kaldi::kSetZero);
feature_cache_.push(feature);
Vector<BaseFloat> feature(dim_, kaldi::kSetZero);
feature_cache_.push(feature);
}
int32 counter = 0;
int32 counter = 0;
int32 cache_size = frame_chunk_size_ - frame_chunk_stride_;
int32 elem_dim = base_extractor_->Dim();
while (counter < frame_chunk_size_) {
Vector<BaseFloat>& val = feature_cache_.front();
int32 start = counter * elem_dim;
feats->Range(start, elem_dim).CopyFromVec(val);
if (frame_chunk_size_ - counter <= cache_size ) {
feature_cache_.push(val);
}
feature_cache_.pop();
counter++;
Vector<BaseFloat>& val = feature_cache_.front();
int32 start = counter * elem_dim;
feats->Range(start, elem_dim).CopyFromVec(val);
if (frame_chunk_size_ - counter <= cache_size) {
feature_cache_.push(val);
}
feature_cache_.pop();
counter++;
}
return result;

@ -25,7 +25,7 @@ struct AssemblerOptions {
int32 receptive_filed_length;
int32 subsampling_rate;
int32 nnet_decoder_chunk;
AssemblerOptions()
: receptive_filed_length(1),
subsampling_rate(1),
@ -47,15 +47,11 @@ class Assembler : public FrontendInterface {
// feat dim
virtual size_t Dim() const { return dim_; }
virtual void SetFinished() {
base_extractor_->SetFinished();
}
virtual void SetFinished() { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() {
base_extractor_->Reset();
}
virtual void Reset() { base_extractor_->Reset(); }
private:
bool Compute(kaldi::Vector<kaldi::BaseFloat>* feats);

@ -30,7 +30,7 @@ class AudioCache : public FrontendInterface {
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
// the audio dim is 1, one sample, which is useless,
// the audio dim is 1, one sample, which is useless,
// so we return size_(cache samples) instead.
virtual size_t Dim() const { return size_; }

@ -29,19 +29,19 @@ using kaldi::Matrix;
using std::vector;
FbankComputer::FbankComputer(const Options& opts)
: opts_(opts),
computer_(opts) {}
: opts_(opts), computer_(opts) {}
int32 FbankComputer::Dim() const {
return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
}
bool FbankComputer::NeedRawLogEnergy() {
return opts_.use_energy && opts_.raw_energy;
return opts_.use_energy && opts_.raw_energy;
}
// Compute feat
bool FbankComputer::Compute(Vector<BaseFloat>* window, Vector<BaseFloat>* feat) {
bool FbankComputer::Compute(Vector<BaseFloat>* window,
Vector<BaseFloat>* feat) {
RealFft(window, true);
kaldi::ComputePowerSpectrum(window);
const kaldi::MelBanks& mel_bank = *(computer_.GetMelBanks(1.0));

@ -72,9 +72,9 @@ bool FeatureCache::Compute() {
bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) return false;
int32 num_chunk = feature.Dim() / dim_ ;
int32 num_chunk = feature.Dim() / dim_;
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * dim_;
int32 start = chunk_idx * dim_;
Vector<BaseFloat> feature_chunk(dim_);
SubVector<BaseFloat> tmp(feature.Data() + start, dim_);
feature_chunk.CopyFromVec(tmp);

@ -22,9 +22,7 @@ namespace ppspeech {
struct FeatureCacheOptions {
int32 max_size;
int32 timeout; // ms
FeatureCacheOptions()
: max_size(kint16max),
timeout(1) {}
FeatureCacheOptions() : max_size(kint16max), timeout(1) {}
};
class FeatureCache : public FrontendInterface {

@ -23,11 +23,11 @@ template <class F>
class StreamingFeatureTpl : public FrontendInterface {
public:
typedef typename F::Options Options;
StreamingFeatureTpl(const Options& opts,
StreamingFeatureTpl(const Options& opts,
std::unique_ptr<FrontendInterface> base_extractor);
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
// the dim_ is the dim of single frame feature
virtual size_t Dim() const { return computer_.Dim(); }
@ -39,8 +39,9 @@ class StreamingFeatureTpl : public FrontendInterface {
base_extractor_->Reset();
remained_wav_.Resize(0);
}
private:
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
kaldi::Vector<kaldi::BaseFloat>* feats);
Options opts_;
std::unique_ptr<FrontendInterface> base_extractor_;

@ -16,16 +16,15 @@
namespace ppspeech {
template <class F>
StreamingFeatureTpl<F>::StreamingFeatureTpl(const Options& opts,
std::unique_ptr<FrontendInterface> base_extractor):
opts_(opts),
computer_(opts),
window_function_(opts.frame_opts) {
StreamingFeatureTpl<F>::StreamingFeatureTpl(
const Options& opts, std::unique_ptr<FrontendInterface> base_extractor)
: opts_(opts), computer_(opts), window_function_(opts.frame_opts) {
base_extractor_ = std::move(base_extractor);
}
template <class F>
void StreamingFeatureTpl<F>::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
void StreamingFeatureTpl<F>::Accept(
const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
base_extractor_->Accept(waves);
}
@ -58,8 +57,9 @@ bool StreamingFeatureTpl<F>::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
// Compute feat
template <class F>
bool StreamingFeatureTpl<F>::Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
kaldi::Vector<kaldi::BaseFloat>* feats) {
bool StreamingFeatureTpl<F>::Compute(
const kaldi::Vector<kaldi::BaseFloat>& waves,
kaldi::Vector<kaldi::BaseFloat>* feats) {
const kaldi::FrameExtractionOptions& frame_opts =
computer_.GetFrameOptions();
int32 num_samples = waves.Dim();
@ -84,9 +84,11 @@ bool StreamingFeatureTpl<F>::Compute(const kaldi::Vector<kaldi::BaseFloat>& wave
&window,
need_raw_log_energy ? &raw_log_energy : NULL);
kaldi::Vector<kaldi::BaseFloat> this_feature(computer_.Dim(), kaldi::kUndefined);
kaldi::Vector<kaldi::BaseFloat> this_feature(computer_.Dim(),
kaldi::kUndefined);
computer_.Compute(&window, &this_feature);
kaldi::SubVector<kaldi::BaseFloat> output_row(feats->Data() + frame * Dim(), Dim());
kaldi::SubVector<kaldi::BaseFloat> output_row(
feats->Data() + frame * Dim(), Dim());
output_row.CopyFromVec(this_feature);
}
return true;

@ -16,6 +16,7 @@
#pragma once
#include "frontend/audio/assembler.h"
#include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/fbank.h"
@ -23,7 +24,6 @@
#include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h"
#include "frontend/audio/assembler.h"
namespace ppspeech {

@ -28,22 +28,21 @@ using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector;
LinearSpectrogramComputer::LinearSpectrogramComputer(
const Options& opts)
LinearSpectrogramComputer::LinearSpectrogramComputer(const Options& opts)
: opts_(opts) {
kaldi::FeatureWindowFunction feature_window_function(opts.frame_opts);
int32 window_size = opts.frame_opts.WindowSize();
frame_length_ = window_size;
dim_ = window_size / 2 + 1;
BaseFloat hanning_window_energy = kaldi::VecVec(feature_window_function.window,
feature_window_function.window);
BaseFloat hanning_window_energy = kaldi::VecVec(
feature_window_function.window, feature_window_function.window);
int32 sample_rate = opts.frame_opts.samp_freq;
scale_ = 2.0 / (hanning_window_energy * sample_rate);
}
// Compute spectrogram feat
bool LinearSpectrogramComputer::Compute(Vector<BaseFloat>* window,
Vector<BaseFloat>* feat) {
Vector<BaseFloat>* feat) {
window->Resize(frame_length_, kaldi::kCopyData);
RealFft(window, true);
kaldi::ComputePowerSpectrum(window);

@ -14,8 +14,8 @@
#include "base/flags.h"
#include "base/log.h"
#include "frontend/audio/data_cache.h"
#include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
@ -75,8 +75,8 @@ int main(int argc, char* argv[]) {
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length
+ (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
@ -130,7 +130,9 @@ int main(int argc, char* argv[]) {
vector<kaldi::BaseFloat> prob;
while (decodable->FrameLikelihood(frame_idx, &prob)) {
kaldi::Vector<kaldi::BaseFloat> vec_tmp(prob.size());
std::memcpy(vec_tmp.Data(), prob.data(), sizeof(kaldi::BaseFloat)*prob.size());
std::memcpy(vec_tmp.Data(),
prob.data(),
sizeof(kaldi::BaseFloat) * prob.size());
prob_vec.push_back(vec_tmp);
frame_idx++;
}
@ -142,7 +144,8 @@ int main(int argc, char* argv[]) {
KALDI_LOG << " the nnet prob of " << utt << " is empty";
continue;
}
kaldi::Matrix<kaldi::BaseFloat> result(prob_vec.size(),prob_vec[0].Dim());
kaldi::Matrix<kaldi::BaseFloat> result(prob_vec.size(),
prob_vec[0].Dim());
for (int32 row_idx = 0; row_idx < prob_vec.size(); ++row_idx) {
for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) {
result(row_idx, col_idx) = prob_vec[row_idx](col_idx);

@ -40,8 +40,8 @@ class WebSocketClient {
void SendEndSignal();
void SendDataEnd();
bool Done() const { return done_; }
std::string GetResult() const { return result_; }
std::string GetPartialResult() const { return partial_result_;}
std::string GetResult() const { return result_; }
std::string GetPartialResult() const { return partial_result_; }
private:
void Connect();

@ -76,9 +76,10 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) {
recognizer_->Accept(pcm_data);
std::string partial_result = recognizer_->GetPartialResult();
json::value rv = {
{"status", "ok"}, {"type", "partial_result"}, {"result", partial_result}};
json::value rv = {{"status", "ok"},
{"type", "partial_result"},
{"result", partial_result}};
ws_.text(true);
ws_.write(asio::buffer(json::serialize(rv)));
}
