diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e3cc36e00..6e7ae1fbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,12 +51,12 @@ repos: language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ - - id: copyright_checker - name: copyright_checker - entry: python .pre-commit-hooks/copyright-check.hook - language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ - exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ + #- id: copyright_checker + # name: copyright_checker + # entry: python .pre-commit-hooks/copyright-check.hook + # language: system + # files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ + # exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 hooks: diff --git a/demos/streaming_asr_server/conf/application.yaml b/demos/streaming_asr_server/conf/application.yaml index e9a89c19d..683d86f03 100644 --- a/demos/streaming_asr_server/conf/application.yaml +++ b/demos/streaming_asr_server/conf/application.yaml @@ -31,6 +31,8 @@ asr_online: force_yes: True device: 'cpu' # cpu or gpu:id decode_method: "attention_rescoring" + continuous_decoding: True # enable continuous decoding when an endpoint is detected + am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True diff --git a/demos/streaming_asr_server/conf/ws_conformer_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_application.yaml index 6a10741bd..9dbc82b6f 100644 --- a/demos/streaming_asr_server/conf/ws_conformer_application.yaml +++ b/demos/streaming_asr_server/conf/ws_conformer_application.yaml @@ -30,6 +30,9 @@ asr_online: decode_method: force_yes: True device: 'cpu' # cpu or gpu:id + decode_method: "attention_rescoring" + continuous_decoding: True # enable continuous decoding when an endpoint is detected + am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True diff --git a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml index e9a89c19d..683d86f03 100644 --- a/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml +++ b/demos/streaming_asr_server/conf/ws_conformer_wenetspeech_application.yaml @@ -31,6 +31,8 @@ asr_online: force_yes: True device: 'cpu' # cpu or gpu:id decode_method: "attention_rescoring" + continuous_decoding: True # enable continuous decoding when an endpoint is detected + am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True diff --git a/demos/streaming_asr_server/server.sh b/demos/streaming_asr_server/server.sh index 4266f8c64..0d255807c 100755 --- a/demos/streaming_asr_server/server.sh +++ b/demos/streaming_asr_server/server.sh @@ -5,4 +5,5 @@ export CUDA_VISIBLE_DEVICE=0,1,2,3 paddlespeech_server start --config_file conf/punc_application.yaml &> punc.log & # nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 & -paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log & \ No newline at end of file +paddlespeech_server start --config_file conf/ws_conformer_application.yaml &> streaming_asr.log & + diff --git a/demos/streaming_asr_server/test.sh
b/demos/streaming_asr_server/test.sh index 4f43c6534..f3075454d 100755 --- a/demos/streaming_asr_server/test.sh +++ b/demos/streaming_asr_server/test.sh @@ -9,4 +9,5 @@ paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --input ./zh.wa # read the wav and call streaming and punc service # If `127.0.0.1` is not accessible, you need to use the actual service IP address. # python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav -paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav \ No newline at end of file +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --input ./zh.wav + diff --git a/examples/wenetspeech/asr1/local/extract_meta.py b/examples/wenetspeech/asr1/local/extract_meta.py index 2cad977be..954dbd780 100644 --- a/examples/wenetspeech/asr1/local/extract_meta.py +++ b/examples/wenetspeech/asr1/local/extract_meta.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import argparse import json import os diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py index 39bf24524..f5e2246d8 100644 --- a/paddlespeech/cli/base_commands.py +++ b/paddlespeech/cli/base_commands.py @@ -145,4 +145,3 @@ for com, info in _commands.items(): name='paddlespeech.{}'.format(com), description=info[0], cls='paddlespeech.cli.{}.{}'.format(com, info[1])) - \ No newline at end of file diff --git a/paddlespeech/cli/cls/infer.py b/paddlespeech/cli/cls/infer.py index 1a9949748..d31379b88 100644 --- a/paddlespeech/cli/cls/infer.py +++ b/paddlespeech/cli/cls/infer.py @@ -21,12 +21,12 @@ from typing import Union import numpy as np import paddle import yaml +from paddleaudio import load +from paddleaudio.features import LogMelSpectrogram from ..executor import BaseExecutor from ..log import logger from ..utils import stats_wrapper -from paddleaudio import load -from paddleaudio.features import LogMelSpectrogram __all__ = ['CLSExecutor'] diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 56f86f9b8..d049ba7da 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -22,13 +22,13 @@ from typing import Union import paddle import soundfile +from paddleaudio.backends import load as load_audio +from paddleaudio.compliance.librosa import melspectrogram from yacs.config import CfgNode from ..executor import BaseExecutor from ..log import logger from ..utils import stats_wrapper -from paddleaudio.backends import load as load_audio -from paddleaudio.compliance.librosa import melspectrogram from paddlespeech.vector.io.batch import feature_normalize from paddlespeech.vector.modules.sid_model import SpeakerIdetification diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py index 0aeda13d4..5309fd86f 100644 --- a/paddlespeech/resource/model_alias.py +++ b/paddlespeech/resource/model_alias.py @@ -22,8 +22,7 @@ model_alias = { # -------------- ASR -------------- # --------------------------------- "deepspeech2offline": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"], - "deepspeech2online": - ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"], + "deepspeech2online": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"], "conformer": 
["paddlespeech.s2t.models.u2:U2Model"], "conformer_online": ["paddlespeech.s2t.models.u2:U2Model"], "transformer": ["paddlespeech.s2t.models.u2:U2Model"], diff --git a/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py b/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py index d8ca5ccde..a994412e0 100644 --- a/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py +++ b/paddlespeech/s2t/decoders/scorers/ctc_prefix_score.py @@ -76,7 +76,8 @@ class CTCPrefixScorePD(): last_ids = [yi[-1] for yi in y] # last output label ids n_bh = len(last_ids) # batch * hyps n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps - self.scoring_num = paddle.shape(scoring_ids)[-1] if scoring_ids is not None else 0 + self.scoring_num = paddle.shape(scoring_ids)[ + -1] if scoring_ids is not None else 0 # prepare state info if state is None: r_prev = paddle.full( diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py index 2cd3c04ac..511997a7c 100644 --- a/paddlespeech/s2t/exps/deepspeech2/model.py +++ b/paddlespeech/s2t/exps/deepspeech2/model.py @@ -22,11 +22,9 @@ import numpy as np import paddle from paddle import distributed as dist from paddle import inference -from paddle.io import DataLoader -from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.io.dataset import ManifestDataset +from paddlespeech.s2t.io.dataloader import BatchDataLoader from paddlespeech.s2t.models.ds2 import DeepSpeech2InferModel from paddlespeech.s2t.models.ds2 import DeepSpeech2Model from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog @@ -238,8 +236,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def __init__(self, config, args): super().__init__(config, args) self._text_featurizer = TextFeaturizer( - unit_type=config.unit_type, - vocab=config.vocab_filepath) + unit_type=config.unit_type, vocab=config.vocab_filepath) self.vocab_list = self._text_featurizer.vocab_list def ordid2token(self, texts, texts_len): @@ -248,7 +245,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): for text, n in zip(texts, texts_len): n = n.numpy().item() ids = text[:n] - trans.append(self._text_featurizer.defeaturize(ids.numpy().tolist())) + trans.append( + self._text_featurizer.defeaturize(ids.numpy().tolist())) return trans def compute_metrics(self, diff --git a/paddlespeech/s2t/models/ds2/__init__.py b/paddlespeech/s2t/models/ds2/__init__.py index 0a5c50d86..480f6d3af 100644 --- a/paddlespeech/s2t/models/ds2/__init__.py +++ b/paddlespeech/s2t/models/ds2/__init__.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import sys + from .deepspeech2 import DeepSpeech2InferModel from .deepspeech2 import DeepSpeech2Model from paddlespeech.s2t.utils import dynamic_pip_install -import sys try: import paddlespeech_ctcdecoders diff --git a/paddlespeech/s2t/models/ds2/deepspeech2.py b/paddlespeech/s2t/models/ds2/deepspeech2.py index 3f2c76ced..b7ee80a7d 100644 --- a/paddlespeech/s2t/models/ds2/deepspeech2.py +++ b/paddlespeech/s2t/models/ds2/deepspeech2.py @@ -372,11 +372,15 @@ class DeepSpeech2InferModel(DeepSpeech2Model): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def forward(self, audio_chunk, audio_chunk_lens, chunk_state_h_box=None, + def forward(self, + audio_chunk, + audio_chunk_lens, + chunk_state_h_box=None, chunk_state_c_box=None): if self.encoder.rnn_direction == "forward": eouts_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box = self.encoder( - audio_chunk, audio_chunk_lens, chunk_state_h_box, chunk_state_c_box) + audio_chunk, audio_chunk_lens, chunk_state_h_box, + chunk_state_c_box) probs_chunk = self.decoder.softmax(eouts_chunk) return probs_chunk, eouts_chunk_lens, final_state_h_box, final_state_c_box elif self.encoder.rnn_direction == "bidirect": @@ -392,8 +396,8 @@ class DeepSpeech2InferModel(DeepSpeech2Model): self, input_spec=[ paddle.static.InputSpec( - shape=[None, None, - self.encoder.feat_size], #[B, chunk_size, feat_dim] + shape=[None, None, self.encoder.feat_size + ], #[B, chunk_size, feat_dim] dtype='float32'), paddle.static.InputSpec(shape=[None], dtype='int64'), # audio_length, [B] diff --git a/paddlespeech/s2t/models/lm/transformer.py b/paddlespeech/s2t/models/lm/transformer.py index d14f99563..04ddddf86 100644 --- a/paddlespeech/s2t/models/lm/transformer.py +++ b/paddlespeech/s2t/models/lm/transformer.py @@ -90,7 +90,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): def _target_mask(self, ys_in_pad): ys_mask = ys_in_pad != 0 - m = subsequent_mask(paddle.shape(ys_mask)[-1])).unsqueeze(0) + m = subsequent_mask(paddle.shape(ys_mask)[-1]).unsqueeze(0) return ys_mask.unsqueeze(-2) & m def forward(self, x: paddle.Tensor, t: paddle.Tensor diff --git a/paddlespeech/s2t/models/u2/updater.py b/paddlespeech/s2t/models/u2/updater.py index 898a50bf0..bb18fe416 100644 --- a/paddlespeech/s2t/models/u2/updater.py +++ b/paddlespeech/s2t/models/u2/updater.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from contextlib import nullcontext import paddle diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index ca576eef1..0f50db21d 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import sys from typing import Union import paddle @@ -22,7 +23,6 @@ from paddlespeech.s2t.modules.align import Linear from paddlespeech.s2t.modules.loss import CTCLoss from paddlespeech.s2t.utils import ctc_utils from paddlespeech.s2t.utils.log import Log -import sys logger = Log(__name__).getlog() diff --git a/paddlespeech/s2t/utils/tensor_utils.py b/paddlespeech/s2t/utils/tensor_utils.py index bc557b130..f9a843ea1 100644 --- a/paddlespeech/s2t/utils/tensor_utils.py +++ b/paddlespeech/s2t/utils/tensor_utils.py @@ -82,7 +82,8 @@ def pad_sequence(sequences: List[paddle.Tensor], max_size = paddle.shape(sequences[0]) # (TODO Hui Zhang): slice not supprot `end==start` # trailing_dims = max_size[1:] - trailing_dims = tuple(max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else () + trailing_dims = tuple( + max_size[1:].numpy().tolist()) if sequences[0].ndim >= 2 else () max_len = max([s.shape[0] for s in sequences]) if batch_first: out_dims = (len(sequences), max_len) + trailing_dims diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index dee8d78ba..d6f5a227c 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -29,6 +29,7 @@ asr_online: cfg_path: decode_method: force_yes: True + device: # cpu or gpu:id am_predictor_conf: device: # set 'gpu:id' or 'cpu' diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml index 9c0425345..dd5e67ca3 100644 --- a/paddlespeech/server/conf/ws_conformer_application.yaml +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -30,6 +30,8 @@ asr_online: decode_method: force_yes: True device: # cpu or gpu:id + continuous_decoding: True # enable continuous decoding when an endpoint is detected + am_predictor_conf: device: # set 'gpu:id' or 'cpu' switch_ir_optim: True diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index a52219730..8fc210e5a 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -13,6 +13,7 @@ # limitations under the License.
import os import sys +from typing import ByteString from typing import Optional import numpy as np @@ -30,9 +31,10 @@ from paddlespeech.s2t.transform.transformation import Transformation from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoingOpt +from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoint from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch from paddlespeech.server.engine.base_engine import BaseEngine -from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor __all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine'] @@ -54,24 +56,35 @@ class PaddleASRConnectionHanddler: self.model_config = asr_engine.executor.config self.asr_engine = asr_engine - self.init() - self.reset() - - def init(self): # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer self.model_type = self.asr_engine.executor.model_type self.sample_rate = self.asr_engine.executor.sample_rate # tokens to text self.text_feature = self.asr_engine.executor.text_feature + # extract feat, now only fbank in conformer model + self.preprocess_conf = self.model_config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + + # frame window and frame shift, in samples unit + self.win_length = self.preprocess_conf.process[0]['win_length'] + self.n_shift = self.preprocess_conf.process[0]['n_shift'] + + assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, ( + self.sample_rate, self.preprocess_conf.process[0]['fs']) + self.frame_shift_in_ms = int( + self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000) + + self.continuous_decoding = self.config.get("continuous_decoding", False) + self.init_decoder() + self.reset() + + def init_decoder(self): if "deepspeech2" in self.model_type: + assert self.continuous_decoding is False, "ds2 model does not support endpoint" self.am_predictor = self.asr_engine.executor.am_predictor - # extract feat, new only fbank in conformer model - self.preprocess_conf = self.model_config.preprocess_config - self.preprocess_args = {"train": False} - self.preprocessing = Transformation(self.preprocess_conf) - self.decoder = CTCDecoder( odim=self.model_config.output_dim, # is in vocab enc_n_units=self.model_config.rnn_layer_size * 2, @@ -90,142 +103,65 @@ class PaddleASRConnectionHanddler: cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) - # frame window and frame shift, in samples unit - self.win_length = self.preprocess_conf.process[0]['win_length'] - self.n_shift = self.preprocess_conf.process[0]['n_shift'] - elif "conformer" in self.model_type or "transformer" in self.model_type: # acoustic model self.model = self.asr_engine.executor.model + self.continuous_decoding = self.config.continuous_decoding + logger.info(f"continuous decoding: {self.continuous_decoding}") # ctc decoding config self.ctc_decode_config = self.asr_engine.executor.config.decode self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) - # extract feat, new only fbank in conformer model - self.preprocess_conf = self.model_config.preprocess_config - self.preprocess_args = {"train": False} - self.preprocessing = Transformation(self.preprocess_conf) - - # frame window
and frame shift, in samples unit - self.win_length = self.preprocess_conf.process[0]['win_length'] - self.n_shift = self.preprocess_conf.process[0]['n_shift'] + # ctc endpoint + self.endpoint_opt = OnlineCTCEndpoingOpt( + frame_shift_in_ms=self.frame_shift_in_ms, blank=0) + self.endpointer = OnlineCTCEndpoint(self.endpoint_opt) else: raise ValueError(f"Not supported: {self.model_type}") - def extract_feat(self, samples): - # we compute the elapsed time of first char occuring - # and we record the start time at the first pcm sample arraving - - if "deepspeech2online" in self.model_type: - # self.reamined_wav stores all the samples, - # include the original remained_wav and this package samples - samples = np.frombuffer(samples, dtype=np.int16) - assert samples.ndim == 1 - - if self.remained_wav is None: - self.remained_wav = samples - else: - assert self.remained_wav.ndim == 1 - self.remained_wav = np.concatenate([self.remained_wav, samples]) - logger.info( - f"The connection remain the audio samples: {self.remained_wav.shape}" - ) - - # fbank - feat = self.preprocessing(self.remained_wav, - **self.preprocess_args) - feat = paddle.to_tensor( - feat, dtype="float32").unsqueeze(axis=0) - - if self.cached_feat is None: - self.cached_feat = feat - else: - assert (len(feat.shape) == 3) - assert (len(self.cached_feat.shape) == 3) - self.cached_feat = paddle.concat( - [self.cached_feat, feat], axis=1) - - # set the feat device - if self.device is None: - self.device = self.cached_feat.place - - # cur frame step - num_frames = feat.shape[1] - - self.num_frames += num_frames - self.remained_wav = self.remained_wav[self.n_shift * num_frames:] - - logger.info( - f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" - ) - logger.info( - f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" - ) - - elif "conformer_online" in self.model_type: - logger.info("Online ASR extract the feat") - samples = np.frombuffer(samples, dtype=np.int16) - assert samples.ndim == 1 - - self.num_samples += samples.shape[0] - logger.info( - f"This package receive {samples.shape[0]} pcm data. 
Global samples:{self.num_samples}" ) - - # self.reamined_wav stores all the samples, - # include the original remained_wav and this package samples - if self.remained_wav is None: - self.remained_wav = samples - else: - assert self.remained_wav.ndim == 1 # (T,) - self.remained_wav = np.concatenate([self.remained_wav, samples]) - logger.info( - f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" - ) - - if len(self.remained_wav) < self.win_length: - # samples not enough for feature window - return 0 + def model_reset(self): + if "deepspeech2" in self.model_type: + return - # fbank - x_chunk = self.preprocessing(self.remained_wav, - **self.preprocess_args) - x_chunk = paddle.to_tensor( - x_chunk, dtype="float32").unsqueeze(axis=0) + # cache for audio and feat + self.remained_wav = None + self.cached_feat = None - # feature cache - if self.cached_feat is None: - self.cached_feat = x_chunk - else: - assert (len(x_chunk.shape) == 3) # (B,T,D) - assert (len(self.cached_feat.shape) == 3) # (B,T,D) - self.cached_feat = paddle.concat( - [self.cached_feat, x_chunk], axis=1) + ## conformer + # cache for conformer online + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_out = None + # conformer decoding state + self.offset = 0 # global offset in decoding frame unit - # set the feat device - if self.device is None: - self.device = self.cached_feat.place + ## just for record info + self.chunk_num = 0 # global decoding chunk num, not used - # cur frame step - num_frames = x_chunk.shape[1] + def output_reset(self): + ## outputs + # partial/ending decoding results + self.result_transcripts = [''] + # token timestamp result + self.word_time_stamp = [] - # global frame step - self.num_frames += num_frames + ## just for record + self.hyps = [] - # update remained wav - self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + # one best timestamp viterbi prob is large. + self.time_stamp = [] - logger.info( - f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" - ) - logger.info( - f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" - ) - logger.info(f"global samples: {self.num_samples}") - logger.info(f"global frames: {self.num_frames}") - else: - raise ValueError(f"not supported: {self.model_type}") + def reset_continuous_decoding(self): + """ + when in continuous decoding, reset for next utterance.
+ """ + self.global_frame_offset = self.num_frames + self.model_reset() + self.searcher.reset() + self.endpointer.reset() + self.output_reset() def reset(self): if "deepspeech2" in self.model_type: @@ -241,38 +177,87 @@ class PaddleASRConnectionHanddler: dtype=float32) self.decoder.reset_decoder(batch_size=1) + if "conformer" in self.model_type or "transformer" in self.model_type: + self.searcher.reset() + self.endpointer.reset() + self.device = None ## common - # global sample and frame step self.num_samples = 0 + self.global_frame_offset = 0 + # frame step of cur utterance self.num_frames = 0 - # cache for audio and feat - self.remained_wav = None - self.cached_feat = None - - # partial/ending decoding results - self.result_transcripts = [''] + ## endpoint + self.endpoint_state = False # True for detect endpoint ## conformer + self.model_reset() - # cache for conformer online - self.subsampling_cache = None - self.elayers_output_cache = None - self.conformer_cnn_cache = None - self.encoder_out = None - # conformer decoding state - self.chunk_num = 0 # globa decoding chunk num - self.offset = 0 # global offset in decoding frame unit - self.hyps = [] + ## outputs + self.output_reset() - # token timestamp result - self.word_time_stamp = [] + def extract_feat(self, samples: ByteString): + logger.info("Online ASR extract the feat") + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 - # one best timestamp viterbi prob is large. - self.time_stamp = [] + self.num_samples += samples.shape[0] + logger.info( + f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}" + ) + + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 # (T,) + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}" + ) + + if len(self.remained_wav) < self.win_length: + # samples not enough for feature window + return 0 + + # fbank + x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args) + x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0) + + # feature cache + if self.cached_feat is None: + self.cached_feat = x_chunk + else: + assert (len(x_chunk.shape) == 3) # (B,T,D) + assert (len(self.cached_feat.shape) == 3) # (B,T,D) + self.cached_feat = paddle.concat( + [self.cached_feat, x_chunk], axis=1) + + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + + # cur frame step + num_frames = x_chunk.shape[1] + + # global frame step + self.num_frames += num_frames + + # update remained wav + self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + + logger.info( + f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}" + ) + logger.info(f"global samples: {self.num_samples}") + logger.info(f"global frames: {self.num_frames}") def decode(self, is_finished=False): """advance decoding @@ -280,14 +265,12 @@ class PaddleASRConnectionHanddler: Args: is_finished (bool, optional): Is last frame or not. Defaults to False. - Raises: - Exception: when not support model. 
- Returns: - None: nothing + None: """ - if "deepspeech2online" in self.model_type: + if "deepspeech2" in self.model_type: decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit + context = 7 # context=7, in audio frame unit subsampling = 4 # subsampling=4, in audio frame unit @@ -332,9 +315,11 @@ class PaddleASRConnectionHanddler: end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) + # extract the audio x_chunk = self.cached_feat[:, cur:end, :].numpy() x_chunk_lens = np.array([x_chunk.shape[1]]) + trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens) self.result_transcripts = [trans_best] @@ -409,31 +394,41 @@ class PaddleASRConnectionHanddler: @paddle.no_grad() def advance_decoding(self, is_finished=False): + if "deepspeech" in self.model_type: + return + + # reset endpoint state + self.endpoint_state = False + logger.info( "Conformer/Transformer: start to decode with advanced_decoding method" ) cfg = self.ctc_decode_config - # cur chunk size, in decoding frame unit + # cur chunk size, in decoding frame unit, e.g. 16 decoding_chunk_size = cfg.decoding_chunk_size - # using num of history chunks + # using num of history chunks, e.g. -1 num_decoding_left_chunks = cfg.num_decoding_left_chunks assert decoding_chunk_size > 0 + # e.g. 4 subsampling = self.model.encoder.embed.subsampling_rate + # e.g. 7 context = self.model.encoder.embed.right_context + 1 - # processed chunk feature cached for next chunk + # processed chunk feature cached for next chunk, e.g. 3 cached_feature_num = context - subsampling - # decoding stride, in audio frame unit - stride = subsampling * decoding_chunk_size + # decoding window, in audio frame unit decoding_window = (decoding_chunk_size - 1) * subsampling + context + # decoding stride, in audio frame unit + stride = subsampling * decoding_chunk_size if self.cached_feat is None: logger.info("no audio feat, please input more pcm data") return + # (B=1,T,D) num_frames = self.cached_feat.shape[1] logger.info( f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" ) @@ -454,9 +449,6 @@ class PaddleASRConnectionHanddler: return None, None logger.info("start to do model forward") - # hist of chunks, in deocding frame unit - required_cache_size = decoding_chunk_size * num_decoding_left_chunks - outputs = [] # num_frames - context + 1 ensure that current frame can get context window if is_finished: @@ -466,7 +458,11 @@ class PaddleASRConnectionHanddler: # we only process decoding_window frames for one chunk left_frames = decoding_window + # hist of chunks, in decoding frame unit + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + # record the end for removing the processed feat + outputs = [] end = None for cur in range(0, num_frames - left_frames + 1, stride): end = min(cur + decoding_window, num_frames) @@ -491,30 +487,40 @@ class PaddleASRConnectionHanddler: self.encoder_out = ys else: self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) + logger.info( + f"This connection handler encoder out shape: {self.encoder_out.shape}" + ) # get the ctc probs ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = ctc_probs.squeeze(0) + ## decoding # advance decoding self.searcher.search(ctc_probs, self.cached_feat.place) # get one best hyps self.hyps = self.searcher.get_one_best_hyps() - assert self.cached_feat.shape[0] == 1 - assert end >= cached_feature_num + # endpoint + if not is_finished: + + def
contain_nonsilence(): + return len(self.hyps) > 0 and len(self.hyps[0]) > 0 + + decoding_something = contain_nonsilence() + if self.endpointer.endpoint_detected(ctc_probs.numpy(), + decoding_something): + self.endpoint_state = True + logger.info(f"Endpoint is detected at {self.num_frames} frame.") # advance cache of feat - self.cached_feat = self.cached_feat[0, end - - cached_feature_num:, :].unsqueeze(0) + assert self.cached_feat.shape[0] == 1 #(B=1,T,D) + assert end >= cached_feature_num + self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] assert len( self.cached_feat.shape ) == 3, f"current cache feat shape is: {self.cached_feat.shape}" - logger.info( - f"This connection handler encoder out shape: {self.encoder_out.shape}" - ) - def update_result(self): """Conformer/Transformer hyps to result. """ @@ -654,24 +660,28 @@ class PaddleASRConnectionHanddler: # update each word start and end time stamp # decoding frame to audio frame - frame_shift = self.model.encoder.embed.subsampling_rate - frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate) - logger.info(f"frame shift sec: {frame_shift_in_sec}") + decode_frame_shift = self.model.encoder.embed.subsampling_rate + decode_frame_shift_in_sec = decode_frame_shift * (self.n_shift / + self.sample_rate) + logger.info(f"decode frame shift in sec: {decode_frame_shift_in_sec}") + + global_offset_in_sec = self.global_frame_offset * self.frame_shift_in_ms / 1000.0 + logger.info(f"global offset: {global_offset_in_sec} sec.") word_time_stamp = [] for idx, _ in enumerate(self.time_stamp): start = (self.time_stamp[idx - 1] + self.time_stamp[idx] ) / 2.0 if idx > 0 else 0 - start = start * frame_shift_in_sec + start = start * decode_frame_shift_in_sec end = (self.time_stamp[idx] + self.time_stamp[idx + 1] ) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset - end = end * frame_shift_in_sec + end = end * decode_frame_shift_in_sec word_time_stamp.append({ "w": self.result_transcripts[0][idx], - "bg": start, - "ed": end + "bg": global_offset_in_sec + start, + "ed": global_offset_in_sec + end }) # logger.info(f"{word_time_stamp[-1]}") @@ -705,13 +715,14 @@ class ASRServerExecutor(ASRExecutor): self.model_type = model_type self.sample_rate = sample_rate + logger.info(f"model_type: {self.model_type}") + sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str self.task_resource.set_task_model(model_tag=tag) + if cfg_path is None or am_model is None or am_params is None: - logger.info(f"Load the pretrained model, tag = {tag}") self.res_path = self.task_resource.res_dir - self.cfg_path = os.path.join( self.res_path, self.task_resource.res_dict['cfg_path']) @@ -719,7 +730,6 @@ class ASRServerExecutor(ASRExecutor): self.task_resource.res_dict['model']) self.am_params = os.path.join(self.res_path, self.task_resource.res_dict['params']) - logger.info(self.res_path) else: self.cfg_path = os.path.abspath(cfg_path) self.am_model = os.path.abspath(am_model) @@ -727,9 +737,12 @@ class ASRServerExecutor(ASRExecutor): self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - logger.info(self.cfg_path) - logger.info(self.am_model) - logger.info(self.am_params) + logger.info("Load the pretrained model:") + logger.info(f" tag = {tag}") + logger.info(f" res_path: {self.res_path}") + logger.info(f" cfg path: {self.cfg_path}") + logger.info(f" am_model path: {self.am_model}") + logger.info(f" am_params path: {self.am_params}") #Init body. 
self.config = CfgNode(new_allowed=True) @@ -738,25 +751,39 @@ class ASRServerExecutor(ASRExecutor): if self.config.spm_model_prefix: self.config.spm_model_prefix = os.path.join( self.res_path, self.config.spm_model_prefix) + logger.info(f"spm model path: {self.config.spm_model_prefix}") + + self.vocab = self.config.vocab_filepath + self.text_feature = TextFeaturizer( unit_type=self.config.unit_type, vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) - self.vocab = self.config.vocab_filepath - with UpdateConfig(self.config): - if "deepspeech2" in model_type: + + if "deepspeech2" in model_type: + with UpdateConfig(self.config): + # download lm self.config.decode.lang_model_path = os.path.join( MODEL_HOME, 'language_model', self.config.decode.lang_model_path) - lm_url = self.task_resource.res_dict['lm_url'] - lm_md5 = self.task_resource.res_dict['lm_md5'] - logger.info(f"Start to load language model {lm_url}") - self.download_lm( - lm_url, - os.path.dirname(self.config.decode.lang_model_path), lm_md5) + lm_url = self.task_resource.res_dict['lm_url'] + lm_md5 = self.task_resource.res_dict['lm_md5'] + logger.info(f"Start to load language model {lm_url}") + self.download_lm( + lm_url, + os.path.dirname(self.config.decode.lang_model_path), lm_md5) - elif "conformer" in model_type or "transformer" in model_type: + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) + + elif "conformer" in model_type or "transformer" in model_type: + with UpdateConfig(self.config): logger.info("start to create the stream conformer asr engine") # update the decoding method if decode_method: @@ -770,37 +797,24 @@ class ASRServerExecutor(ASRExecutor): logger.info( "we set the decoding_method to attention_rescoring") self.config.decode.decoding_method = "attention_rescoring" + assert self.config.decode.decoding_method in [ "ctc_prefix_beam_search", "attention_rescoring" ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" - else: - raise Exception("wrong type") - if "deepspeech2" in model_type: - # AM predictor - logger.info("ASR engine start to init the am predictor") - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - elif "conformer" in model_type or "transformer" in model_type: + # load model model_name = model_type[:model_type.rindex( '_')] # model_type: {model_name}_{dataset} logger.info(f"model name: {model_name}") model_class = self.task_resource.get_model_class(model_name) - model_conf = self.config - model = model_class.from_config(model_conf) + model = model_class.from_config(self.config) self.model = model + self.model.set_state_dict(paddle.load(self.am_model)) self.model.eval() - - # load model - model_dict = paddle.load(self.am_model) - self.model.set_state_dict(model_dict) - logger.info("create the transformer like model success") else: - raise ValueError(f"Not support: {model_type}") + raise Exception(f"not support: {model_type}") + logger.info(f"create the {model_type} model success") return True @@ -857,6 +871,14 @@ class ASREngine(BaseEngine): logger.info("Initialize ASR server engine successfully.") return True + def new_handler(self): + """New 
handler from model. + + Returns: + PaddleASRConnectionHanddler: asr handler instance + """ + return PaddleASRConnectionHanddler(self) + def preprocess(self, *args, **kwargs): raise NotImplementedError("Online not using this.") diff --git a/paddlespeech/server/engine/asr/online/ctc_endpoint.py b/paddlespeech/server/engine/asr/online/ctc_endpoint.py new file mode 100644 index 000000000..2dba36417 --- /dev/null +++ b/paddlespeech/server/engine/asr/online/ctc_endpoint.py @@ -0,0 +1,118 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +import numpy as np + +from paddlespeech.cli.log import logger + + +@dataclass +class OnlineCTCEndpointRule: + must_contain_nonsilence: bool = True + min_trailing_silence: int = 1000 + min_utterance_length: int = 0 + + +@dataclass +class OnlineCTCEndpoingOpt: + frame_shift_in_ms: int = 10 + + blank: int = 0 # blank id, that we consider as silence for purposes of endpointing. + blank_threshold: float = 0.8 # above blank threshold is silence + + # We support three rules. We terminate decoding if ANY of these rules + # evaluates to "true". If you want to add more rules, do it by changing this + # code. If you want to disable a rule, you can set the silence-timeout for + # that rule to a very large number. + + # rule1 times out after 5 seconds of silence, even if we decoded nothing. + rule1: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 5000, 0) + # rule2 times out after 1.0 seconds of silence after decoding something, + # even if we did not reach a final-state at all. + rule2: OnlineCTCEndpointRule = OnlineCTCEndpointRule(True, 1000, 0) + # rule3 times out after the utterance is 20 seconds long, regardless of + # anything else. + rule3: OnlineCTCEndpointRule = OnlineCTCEndpointRule(False, 0, 20000) + + +class OnlineCTCEndpoint: + """ + [END-TO-END AUTOMATIC SPEECH RECOGNITION INTEGRATED WITH CTC-BASED VOICE ACTIVITY DETECTION](https://arxiv.org/pdf/2002.00551.pdf) + """ + + def __init__(self, opts: OnlineCTCEndpoingOpt): + self.opts = opts + logger.info(f"Endpoint Opts: {opts}") + self.frame_shift_in_ms = opts.frame_shift_in_ms + + self.num_frames_decoded = 0 + self.trailing_silence_frames = 0 + + self.reset() + + def reset(self): + self.num_frames_decoded = 0 + self.trailing_silence_frames = 0 + + def rule_activated(self, + rule: OnlineCTCEndpointRule, + rule_name: str, + decoding_something: bool, + trailing_silence: int, + utterance_length: int) -> bool: + ans = ( + decoding_something or (not rule.must_contain_nonsilence) + ) and trailing_silence >= rule.min_trailing_silence and utterance_length >= rule.min_utterance_length + if (ans): + logger.info(f"Endpoint Rule: {rule_name} activated: {rule}") + return ans + + def endpoint_detected(self, + ctc_log_probs: np.ndarray, + decoding_something: bool) -> bool: + """detect endpoint. + + Args: + ctc_log_probs (np.ndarray): (T, D) + decoding_something (bool): contains non-silence.
+ + Returns: + bool: whether endpoint detected. + """ + for logprob in ctc_log_probs: + blank_prob = np.exp(logprob[self.opts.blank]) + + self.num_frames_decoded += 1 + if blank_prob > self.opts.blank_threshold: + self.trailing_silence_frames += 1 + else: + self.trailing_silence_frames = 0 + + assert self.num_frames_decoded >= self.trailing_silence_frames + assert self.frame_shift_in_ms > 0 + + utterance_length = self.num_frames_decoded * self.frame_shift_in_ms + trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms + + if self.rule_activated(self.opts.rule1, 'rule1', decoding_something, + trailing_silence, utterance_length): + return True + if self.rule_activated(self.opts.rule2, 'rule2', decoding_something, + trailing_silence, utterance_length): + return True + if self.rule_activated(self.opts.rule3, 'rule3', decoding_something, + trailing_silence, utterance_length): + return True + return False diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index 4c9ac3acb..46f310c80 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -30,8 +30,29 @@ class CTCPrefixBeamSearch: config (yacs.config.CfgNode): the ctc prefix beam search configuration """ self.config = config + + # beam size + self.first_beam_size = self.config.beam_size + # TODO(support second beam size) + self.second_beam_size = int(self.first_beam_size * 1.0) + logger.info( + f"first and second beam size: {self.first_beam_size}, {self.second_beam_size}" + ) + + # state + self.cur_hyps = None + self.hyps = None + self.abs_time_step = 0 + self.reset() + def reset(self): + """Reset the search cache value + """ + self.cur_hyps = None + self.hyps = None + self.abs_time_step = 0 + @paddle.no_grad() def search(self, ctc_probs, device, blank_id=0): """ctc prefix beam search method decode a chunk feature @@ -47,12 +68,17 @@ class CTCPrefixBeamSearch: """ # decode logger.info("start to ctc prefix search") - + assert len(ctc_probs.shape) == 2 batch_size = 1 - beam_size = self.config.beam_size - maxlen = ctc_probs.shape[0] - assert len(ctc_probs.shape) == 2 + vocab_size = ctc_probs.shape[1] + first_beam_size = min(self.first_beam_size, vocab_size) + second_beam_size = min(self.second_beam_size, vocab_size) + logger.info( + f"effective first and second beam size: {first_beam_size}, {second_beam_size}" + ) + + maxlen = ctc_probs.shape[0] # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) # 0.
blank_ending_score, @@ -75,7 +101,8 @@ class CTCPrefixBeamSearch: # 2.1 First beam prune: select topk best # do token passing process - top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + top_k_logp, top_k_index = logp.topk( + first_beam_size) # (first_beam_size,) for s in top_k_index: s = s.item() ps = logp[s].item() @@ -148,7 +175,7 @@ class CTCPrefixBeamSearch: next_hyps.items(), key=lambda x: log_add([x[1][0], x[1][1]]), reverse=True) - self.cur_hyps = next_hyps[:beam_size] + self.cur_hyps = next_hyps[:second_beam_size] # 2.3 update the absolute time step self.abs_time_step += 1 @@ -163,7 +190,7 @@ class CTCPrefixBeamSearch: """Return the one best result Returns: - list: the one best result + list: the one best result, List[str] """ return [self.hyps[0][0]] @@ -171,17 +198,10 @@ class CTCPrefixBeamSearch: """Return the search hyps Returns: - list: return the search hyps + list: return the search hyps, List[Tuple[str, float, ...]] """ return self.hyps - def reset(self): - """Rest the search cache value - """ - self.cur_hyps = None - self.hyps = None - self.abs_time_step = 0 - def finalize_search(self): """do nothing in ctc_prefix_beam_search """ diff --git a/paddlespeech/server/engine/tts/online/python/tts_engine.py b/paddlespeech/server/engine/tts/online/python/tts_engine.py index 6c309f026..2e8997e0f 100644 --- a/paddlespeech/server/engine/tts/online/python/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/python/tts_engine.py @@ -42,7 +42,6 @@ class TTSServerExecutor(TTSExecutor): self.task_resource = CommonTaskResource( task='tts', model_format='dynamic', inference_mode='online') - def get_model_info(self, field: str, model_name: str, diff --git a/paddlespeech/server/ws/asr_api.py b/paddlespeech/server/ws/asr_api.py index 0faa131aa..ff4ec7b0b 100644 --- a/paddlespeech/server/ws/asr_api.py +++ b/paddlespeech/server/ws/asr_api.py @@ -19,7 +19,6 @@ from fastapi import WebSocketDisconnect from starlette.websockets import WebSocketState as WebSocketState from paddlespeech.cli.log import logger -from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler from paddlespeech.server.engine.engine_pool import get_engine_pool router = APIRouter() @@ -38,7 +37,7 @@ async def websocket_endpoint(websocket: WebSocket): #2. if we accept the websocket headers, we will get the online asr engine instance engine_pool = get_engine_pool() - asr_engine = engine_pool['asr'] + asr_model = engine_pool['asr'] #3. 
each websocket connection, we will create an PaddleASRConnectionHanddler to process such audio # and each connection has its own connection instance to process the request @@ -70,7 +69,8 @@ async def websocket_endpoint(websocket: WebSocket): resp = {"status": "ok", "signal": "server_ready"} # do something at begining here # create the instance to process the audio - connection_handler = PaddleASRConnectionHanddler(asr_engine) + #connection_handler = PaddleASRConnectionHanddler(asr_model) + connection_handler = asr_model.new_handler() await websocket.send_json(resp) elif message['signal'] == 'end': # reset single engine for an new connection @@ -100,11 +100,34 @@ async def websocket_endpoint(websocket: WebSocket): # and decode for the result in this package data connection_handler.extract_feat(message) connection_handler.decode(is_finished=False) + + if connection_handler.endpoint_state: + logger.info("endpoint: detected and rescoring.") + connection_handler.rescoring() + word_time_stamp = connection_handler.get_word_time_stamp() + asr_results = connection_handler.get_result() - # return the current period result - # if the engine create the vad instance, this connection will have many period results + if connection_handler.endpoint_state: + if connection_handler.continuous_decoding: + logger.info("endpoint: continue decoding") + connection_handler.reset_continuous_decoding() + else: + logger.info("endpoint: exit decoding") + # ending by endpoint + resp = { + "status": "ok", + "signal": "finished", + 'result': asr_results, + 'times': word_time_stamp + } + await websocket.send_json(resp) + break + + # return the current partial result + # if the engine create the vad instance, this connection will have many partial results resp = {'result': asr_results} await websocket.send_json(resp) + except WebSocketDisconnect as e: logger.error(e) diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py index 2b543ef7d..9ddab726e 100644 --- a/paddlespeech/t2s/exps/synthesize.py +++ b/paddlespeech/t2s/exps/synthesize.py @@ -140,10 +140,7 @@ def parse_args(): ], help='Choose acoustic model type of tts task.') parser.add_argument( - '--am_config', - type=str, - default=None, - help='Config of acoustic model.') + '--am_config', type=str, default=None, help='Config of acoustic model.') parser.add_argument( '--am_ckpt', type=str, @@ -179,10 +176,7 @@ def parse_args(): ], help='Choose vocoder type of tts task.') parser.add_argument( - '--voc_config', - type=str, - default=None, - help='Config of voc.') + '--voc_config', type=str, default=None, help='Config of voc.') parser.add_argument( '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 6101d1593..28657eb27 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -174,10 +174,7 @@ def parse_args(): ], help='Choose acoustic model type of tts task.') parser.add_argument( - '--am_config', - type=str, - default=None, - help='Config of acoustic model.') + '--am_config', type=str, default=None, help='Config of acoustic model.') parser.add_argument( '--am_ckpt', type=str, @@ -220,10 +217,7 @@ def parse_args(): ], help='Choose vocoder type of tts task.') parser.add_argument( - '--voc_config', - type=str, - default=None, - help='Config of voc.') + '--voc_config', type=str, default=None, help='Config of voc.') parser.add_argument( '--voc_ckpt', type=str, 
default=None, help='Checkpoint file of voc.') parser.add_argument( diff --git a/paddlespeech/t2s/exps/voice_cloning.py b/paddlespeech/t2s/exps/voice_cloning.py index 4858d2d56..b51a4d7bc 100644 --- a/paddlespeech/t2s/exps/voice_cloning.py +++ b/paddlespeech/t2s/exps/voice_cloning.py @@ -131,10 +131,7 @@ def parse_args(): choices=['fastspeech2_aishell3', 'tacotron2_aishell3'], help='Choose acoustic model type of tts task.') parser.add_argument( - '--am_config', - type=str, - default=None, - help='Config of acoustic model.') + '--am_config', type=str, default=None, help='Config of acoustic model.') parser.add_argument( '--am_ckpt', type=str, @@ -160,10 +157,7 @@ def parse_args(): help='Choose vocoder type of tts task.') parser.add_argument( - '--voc_config', - type=str, - default=None, - help='Config of voc.') + '--voc_config', type=str, default=None, help='Config of voc.') parser.add_argument( '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( diff --git a/paddlespeech/t2s/models/vits/__init__.py b/paddlespeech/t2s/models/vits/__init__.py index 2c23aa3ec..ea43028ae 100644 --- a/paddlespeech/t2s/models/vits/__init__.py +++ b/paddlespeech/t2s/models/vits/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from .vits import * -from .vits_updater import * \ No newline at end of file +from .vits_updater import * diff --git a/paddlespeech/t2s/models/vits/vits_updater.py b/paddlespeech/t2s/models/vits/vits_updater.py index a031dc575..76271fd97 100644 --- a/paddlespeech/t2s/models/vits/vits_updater.py +++ b/paddlespeech/t2s/models/vits/vits_updater.py @@ -56,7 +56,8 @@ class VITSUpdater(StandardUpdater): self.models: Dict[str, Layer] = models # self.model = model - self.model = model._layers if isinstance(model, paddle.DataParallel) else model + self.model = model._layers if isinstance(model, + paddle.DataParallel) else model self.optimizers = optimizers self.optimizer_g: Optimizer = optimizers['generator'] @@ -225,7 +226,8 @@ class VITSEvaluator(StandardEvaluator): models = {"main": model} self.models: Dict[str, Layer] = models # self.model = model - self.model = model._layers if isinstance(model, paddle.DataParallel) else model + self.model = model._layers if isinstance(model, + paddle.DataParallel) else model self.criterions = criterions self.criterion_mel = criterions['mel'] diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index 64678341c..e6ab93513 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -971,18 +971,18 @@ class FeatureMatchLoss(nn.Layer): return feat_match_loss + # loss for VITS class KLDivergenceLoss(nn.Layer): """KL divergence loss.""" def forward( - self, - z_p: paddle.Tensor, - logs_q: paddle.Tensor, - m_p: paddle.Tensor, - logs_p: paddle.Tensor, - z_mask: paddle.Tensor, - ) -> paddle.Tensor: + self, + z_p: paddle.Tensor, + logs_q: paddle.Tensor, + m_p: paddle.Tensor, + logs_p: paddle.Tensor, + z_mask: paddle.Tensor, ) -> paddle.Tensor: """Calculate KL divergence loss. 
Args: @@ -1002,8 +1002,8 @@ class KLDivergenceLoss(nn.Layer): logs_p = paddle.cast(logs_p, 'float32') z_mask = paddle.cast(z_mask, 'float32') kl = logs_p - logs_q - 0.5 - kl += 0.5 * ((z_p - m_p) ** 2) * paddle.exp(-2.0 * logs_p) + kl += 0.5 * ((z_p - m_p)**2) * paddle.exp(-2.0 * logs_p) kl = paddle.sum(kl * z_mask) loss = kl / paddle.sum(z_mask) - return loss \ No newline at end of file + return loss diff --git a/speechx/examples/README.md b/speechx/examples/README.md index 1b977523c..f7f6f9ac0 100644 --- a/speechx/examples/README.md +++ b/speechx/examples/README.md @@ -25,4 +25,3 @@ netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host > Reminder: Only for developer, make sure you know what's it. * codelab - for speechx developer, using for test. - diff --git a/speechx/examples/ds2_ol/README.md b/speechx/examples/ds2_ol/README.md index 54f3abda4..492d0e1ac 100644 --- a/speechx/examples/ds2_ol/README.md +++ b/speechx/examples/ds2_ol/README.md @@ -3,4 +3,4 @@ ## Examples * `websocket` - Streaming ASR with websocket for deepspeech2_aishell. -* `aishell` - Streaming Decoding under aishell dataset, for local WER test. \ No newline at end of file +* `aishell` - Streaming Decoding under aishell dataset, for local WER test. diff --git a/speechx/speechx/codelab/README.md b/speechx/speechx/codelab/README.md index aee60de67..077c4cef2 100644 --- a/speechx/speechx/codelab/README.md +++ b/speechx/speechx/codelab/README.md @@ -4,4 +4,3 @@ > Reminder: Only for developer. * codelab - for speechx developer, using for test. - diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc index f452636fb..7cfee06c9 100644 --- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc +++ b/speechx/speechx/decoder/ctc_prefix_beam_search_decoder_main.cc @@ -91,8 +91,8 @@ int main(int argc, char* argv[]) { std::shared_ptr decodable( new ppspeech::Decodable(nnet, raw_data)); - int32 chunk_size = FLAGS_receptive_field_length - + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; + int32 chunk_size = FLAGS_receptive_field_length + + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc index 3f8bdd5a7..712d27dd4 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -64,7 +64,7 @@ std::string TLGDecoder::GetPartialResult() { std::string word = word_symbol_table_->Find(words_id[idx]); words += word; } - return words; + return words; } std::string TLGDecoder::GetFinalBestPath() { diff --git a/speechx/speechx/decoder/param.h b/speechx/speechx/decoder/param.h index 4c5abfad7..d6ee27058 100644 --- a/speechx/speechx/decoder/param.h +++ b/speechx/speechx/decoder/param.h @@ -82,7 +82,7 @@ FeaturePipelineOptions InitFeaturePipelineOptions() { opts.assembler_opts.subsampling_rate = FLAGS_downsampling_rate; opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length; opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk; - + return opts; } diff --git a/speechx/speechx/decoder/tlg_decoder_main.cc b/speechx/speechx/decoder/tlg_decoder_main.cc index 010acccf1..b175ed135 100644 --- a/speechx/speechx/decoder/tlg_decoder_main.cc +++ 
b/speechx/speechx/decoder/tlg_decoder_main.cc @@ -93,8 +93,8 @@ int main(int argc, char* argv[]) { std::shared_ptr decodable( new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); - int32 chunk_size = FLAGS_receptive_field_length - + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; + int32 chunk_size = FLAGS_receptive_field_length + + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 receptive_field_length = FLAGS_receptive_field_length; LOG(INFO) << "chunk size (frame): " << chunk_size; diff --git a/speechx/speechx/frontend/audio/assembler.cc b/speechx/speechx/frontend/audio/assembler.cc index 721d2ad0d..37eeec80f 100644 --- a/speechx/speechx/frontend/audio/assembler.cc +++ b/speechx/speechx/frontend/audio/assembler.cc @@ -24,7 +24,8 @@ using std::unique_ptr; Assembler::Assembler(AssemblerOptions opts, unique_ptr base_extractor) { frame_chunk_stride_ = opts.subsampling_rate * opts.nnet_decoder_chunk; - frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate + opts.receptive_filed_length; + frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate + + opts.receptive_filed_length; receptive_filed_length_ = opts.receptive_filed_length; base_extractor_ = std::move(base_extractor); dim_ = base_extractor_->Dim(); @@ -50,8 +51,8 @@ bool Assembler::Compute(Vector* feats) { Vector feature; result = base_extractor_->Read(&feature); if (result == false || feature.Dim() == 0) { - if (IsFinished() == false) return false; - break; + if (IsFinished() == false) return false; + break; } feature_cache_.push(feature); } @@ -61,22 +62,22 @@ bool Assembler::Compute(Vector* feats) { } while (feature_cache_.size() < frame_chunk_size_) { - Vector feature(dim_, kaldi::kSetZero); - feature_cache_.push(feature); + Vector feature(dim_, kaldi::kSetZero); + feature_cache_.push(feature); } - int32 counter = 0; + int32 counter = 0; int32 cache_size = frame_chunk_size_ - frame_chunk_stride_; int32 elem_dim = base_extractor_->Dim(); while (counter < frame_chunk_size_) { - Vector& val = feature_cache_.front(); - int32 start = counter * elem_dim; - feats->Range(start, elem_dim).CopyFromVec(val); - if (frame_chunk_size_ - counter <= cache_size ) { - feature_cache_.push(val); - } - feature_cache_.pop(); - counter++; + Vector& val = feature_cache_.front(); + int32 start = counter * elem_dim; + feats->Range(start, elem_dim).CopyFromVec(val); + if (frame_chunk_size_ - counter <= cache_size) { + feature_cache_.push(val); + } + feature_cache_.pop(); + counter++; } return result; diff --git a/speechx/speechx/frontend/audio/assembler.h b/speechx/speechx/frontend/audio/assembler.h index a477df99f..258e61f2b 100644 --- a/speechx/speechx/frontend/audio/assembler.h +++ b/speechx/speechx/frontend/audio/assembler.h @@ -25,7 +25,7 @@ struct AssemblerOptions { int32 receptive_filed_length; int32 subsampling_rate; int32 nnet_decoder_chunk; - + AssemblerOptions() : receptive_filed_length(1), subsampling_rate(1), @@ -47,15 +47,11 @@ class Assembler : public FrontendInterface { // feat dim virtual size_t Dim() const { return dim_; } - virtual void SetFinished() { - base_extractor_->SetFinished(); - } + virtual void SetFinished() { base_extractor_->SetFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); } - virtual void Reset() { - base_extractor_->Reset(); - } + virtual void Reset() { base_extractor_->Reset(); } private: bool Compute(kaldi::Vector* feats); diff --git 
diff --git a/speechx/speechx/frontend/audio/audio_cache.h b/speechx/speechx/frontend/audio/audio_cache.h
index e31a8aeb5..fc07d4bab 100644
--- a/speechx/speechx/frontend/audio/audio_cache.h
+++ b/speechx/speechx/frontend/audio/audio_cache.h
@@ -30,7 +30,7 @@ class AudioCache : public FrontendInterface {
 
     virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* waves);
 
-    // the audio dim is 1, one sample, which is useless, 
+    // the audio dim is 1, one sample, which is useless,
     // so we return size_(cache samples) instead.
     virtual size_t Dim() const { return size_; }
 
diff --git a/speechx/speechx/frontend/audio/fbank.cc b/speechx/speechx/frontend/audio/fbank.cc
index 1f22263a4..059abbbd1 100644
--- a/speechx/speechx/frontend/audio/fbank.cc
+++ b/speechx/speechx/frontend/audio/fbank.cc
@@ -29,19 +29,19 @@ using kaldi::Matrix;
 using std::vector;
 
 FbankComputer::FbankComputer(const Options& opts)
-    : opts_(opts),
-      computer_(opts) {}
+    : opts_(opts), computer_(opts) {}
 
 int32 FbankComputer::Dim() const {
     return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
 }
 
 bool FbankComputer::NeedRawLogEnergy() {
-    return opts_.use_energy && opts_.raw_energy; 
+    return opts_.use_energy && opts_.raw_energy;
 }
 
 // Compute feat
-bool FbankComputer::Compute(Vector<BaseFloat>* window, Vector<BaseFloat>* feat) {
+bool FbankComputer::Compute(Vector<BaseFloat>* window,
+                            Vector<BaseFloat>* feat) {
     RealFft(window, true);
     kaldi::ComputePowerSpectrum(window);
     const kaldi::MelBanks& mel_bank = *(computer_.GetMelBanks(1.0));
diff --git a/speechx/speechx/frontend/audio/feature_cache.cc b/speechx/speechx/frontend/audio/feature_cache.cc
index 930f29c54..509a98c3b 100644
--- a/speechx/speechx/frontend/audio/feature_cache.cc
+++ b/speechx/speechx/frontend/audio/feature_cache.cc
@@ -72,9 +72,9 @@ bool FeatureCache::Compute() {
     bool result = base_extractor_->Read(&feature);
     if (result == false || feature.Dim() == 0) return false;
 
-    int32 num_chunk = feature.Dim() / dim_ ;
+    int32 num_chunk = feature.Dim() / dim_;
     for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
-       int32 start = chunk_idx * dim_;
+        int32 start = chunk_idx * dim_;
         Vector<BaseFloat> feature_chunk(dim_);
         SubVector<BaseFloat> tmp(feature.Data() + start, dim_);
         feature_chunk.CopyFromVec(tmp);
diff --git a/speechx/speechx/frontend/audio/feature_cache.h b/speechx/speechx/frontend/audio/feature_cache.h
index 4c016056a..b922de12c 100644
--- a/speechx/speechx/frontend/audio/feature_cache.h
+++ b/speechx/speechx/frontend/audio/feature_cache.h
@@ -22,9 +22,7 @@ namespace ppspeech {
 struct FeatureCacheOptions {
     int32 max_size;
     int32 timeout;  // ms
-    FeatureCacheOptions()
-        : max_size(kint16max),
-          timeout(1) {}
+    FeatureCacheOptions() : max_size(kint16max), timeout(1) {}
 };
 
 class FeatureCache : public FrontendInterface {
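The feature_cache.cc hunk above splits one flattened feature vector into dim_-sized frames (num_chunk = feature.Dim() / dim_). A minimal sketch of the same indexing with std::vector, where dim = 4 is chosen purely for illustration:

    #include <cassert>
    #include <vector>

    int main() {
        const int dim = 4;  // illustrative; speechx takes it from the extractor
        // A flattened buffer holding 3 frames of 4 values each.
        std::vector<float> feature(12, 1.0f);

        int num_chunk = static_cast<int>(feature.size()) / dim;
        std::vector<std::vector<float>> frames;
        for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
            int start = chunk_idx * dim;
            // Copy one dim-sized slice, like SubVector + CopyFromVec above.
            frames.emplace_back(feature.begin() + start,
                                feature.begin() + start + dim);
        }
        assert(frames.size() == 3 && frames[0].size() == 4);
        return 0;
    }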
diff --git a/speechx/speechx/frontend/audio/feature_common.h b/speechx/speechx/frontend/audio/feature_common.h
index e03634d35..bad705c9f 100644
--- a/speechx/speechx/frontend/audio/feature_common.h
+++ b/speechx/speechx/frontend/audio/feature_common.h
@@ -23,11 +23,11 @@ template <class F>
 class StreamingFeatureTpl : public FrontendInterface {
   public:
     typedef typename F::Options Options;
-    StreamingFeatureTpl(const Options& opts, 
+    StreamingFeatureTpl(const Options& opts,
                         std::unique_ptr<FrontendInterface> base_extractor);
     virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
     virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
-    
+
     // the dim_ is the dim of single frame feature
     virtual size_t Dim() const { return computer_.Dim(); }
 
@@ -39,8 +39,9 @@ class StreamingFeatureTpl : public FrontendInterface {
         base_extractor_->Reset();
         remained_wav_.Resize(0);
     }
+
   private:
-    bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
+    bool Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
                  kaldi::Vector<kaldi::BaseFloat>* feats);
     Options opts_;
     std::unique_ptr<FrontendInterface> base_extractor_;
diff --git a/speechx/speechx/frontend/audio/feature_common_inl.h b/speechx/speechx/frontend/audio/feature_common_inl.h
index a482ef557..b86f79918 100644
--- a/speechx/speechx/frontend/audio/feature_common_inl.h
+++ b/speechx/speechx/frontend/audio/feature_common_inl.h
@@ -16,16 +16,15 @@
 namespace ppspeech {
 
 template <class F>
-StreamingFeatureTpl<F>::StreamingFeatureTpl(const Options& opts,
-    std::unique_ptr<FrontendInterface> base_extractor):
-    opts_(opts),
-    computer_(opts),
-    window_function_(opts.frame_opts) {
+StreamingFeatureTpl<F>::StreamingFeatureTpl(
+    const Options& opts, std::unique_ptr<FrontendInterface> base_extractor)
+    : opts_(opts), computer_(opts), window_function_(opts.frame_opts) {
     base_extractor_ = std::move(base_extractor);
 }
 
 template <class F>
-void StreamingFeatureTpl<F>::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
+void StreamingFeatureTpl<F>::Accept(
+    const kaldi::VectorBase<kaldi::BaseFloat>& waves) {
     base_extractor_->Accept(waves);
 }
 
@@ -58,8 +57,9 @@ bool StreamingFeatureTpl<F>::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
 
 // Compute feat
 template <class F>
-bool StreamingFeatureTpl<F>::Compute(const kaldi::Vector<kaldi::BaseFloat>& waves,
-                                     kaldi::Vector<kaldi::BaseFloat>* feats) {
+bool StreamingFeatureTpl<F>::Compute(
+    const kaldi::Vector<kaldi::BaseFloat>& waves,
+    kaldi::Vector<kaldi::BaseFloat>* feats) {
     const kaldi::FrameExtractionOptions& frame_opts =
         computer_.GetFrameOptions();
     int32 num_samples = waves.Dim();
@@ -84,9 +84,11 @@ bool StreamingFeatureTpl<F>::Compute(const kaldi::Vector<kaldi::BaseFloat>& wave
             &window,
             need_raw_log_energy ? &raw_log_energy : NULL);
 
-        kaldi::Vector<kaldi::BaseFloat> this_feature(computer_.Dim(), kaldi::kUndefined);
+        kaldi::Vector<kaldi::BaseFloat> this_feature(computer_.Dim(),
+                                                     kaldi::kUndefined);
         computer_.Compute(&window, &this_feature);
-        kaldi::SubVector<kaldi::BaseFloat> output_row(feats->Data() + frame * Dim(), Dim());
+        kaldi::SubVector<kaldi::BaseFloat> output_row(
+            feats->Data() + frame * Dim(), Dim());
         output_row.CopyFromVec(this_feature);
     }
     return true;
diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/speechx/speechx/frontend/audio/feature_pipeline.h
index 49b2f2678..48f95e3f3 100644
--- a/speechx/speechx/frontend/audio/feature_pipeline.h
+++ b/speechx/speechx/frontend/audio/feature_pipeline.h
@@ -16,6 +16,7 @@
 
 #pragma once
 
+#include "frontend/audio/assembler.h"
 #include "frontend/audio/audio_cache.h"
 #include "frontend/audio/data_cache.h"
 #include "frontend/audio/fbank.h"
@@ -23,7 +24,6 @@
 #include "frontend/audio/frontend_itf.h"
 #include "frontend/audio/linear_spectrogram.h"
 #include "frontend/audio/normalizer.h"
-#include "frontend/audio/assembler.h"
 
 namespace ppspeech {
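feature_common.h and feature_common_inl.h above implement the frontend as a class template parameterized by a feature "computer" (Fbank and LinearSpectrogram each supply Options, Dim(), and Compute()). A compilable sketch of that pattern; FakeComputer and every name in it are invented for illustration, not real speechx types:

    #include <vector>

    struct FakeOptions {
        int num_bins = 80;
    };

    class FakeComputer {
      public:
        typedef FakeOptions Options;
        explicit FakeComputer(const Options& opts) : opts_(opts) {}
        int Dim() const { return opts_.num_bins; }
        // Real computers run FFT/mel math here; this one just zero-fills.
        bool Compute(std::vector<float>* window, std::vector<float>* feat) const {
            feat->assign(Dim(), 0.0f);
            return true;
        }

      private:
        Options opts_;
    };

    // Like StreamingFeatureTpl<F>: the template owns an F and delegates
    // Dim()/Compute() to it, so switching Fbank vs. LinearSpectrogram
    // is a one-type change at instantiation time.
    template <class F>
    class StreamingFeature {
      public:
        typedef typename F::Options Options;
        explicit StreamingFeature(const Options& opts) : computer_(opts) {}
        int Dim() const { return computer_.Dim(); }

      private:
        F computer_;
    };

    int main() {
        StreamingFeature<FakeComputer> feature(FakeOptions{});
        return feature.Dim() == 80 ? 0 : 1;
    }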
diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.cc b/speechx/speechx/frontend/audio/linear_spectrogram.cc
index 76580fd51..55c039787 100644
--- a/speechx/speechx/frontend/audio/linear_spectrogram.cc
+++ b/speechx/speechx/frontend/audio/linear_spectrogram.cc
@@ -28,22 +28,21 @@ using kaldi::VectorBase;
 using kaldi::Matrix;
 using std::vector;
 
-LinearSpectrogramComputer::LinearSpectrogramComputer(
-    const Options& opts)
+LinearSpectrogramComputer::LinearSpectrogramComputer(const Options& opts)
     : opts_(opts) {
     kaldi::FeatureWindowFunction feature_window_function(opts.frame_opts);
     int32 window_size = opts.frame_opts.WindowSize();
     frame_length_ = window_size;
     dim_ = window_size / 2 + 1;
-    BaseFloat hanning_window_energy = kaldi::VecVec(feature_window_function.window,
-                                                    feature_window_function.window);
+    BaseFloat hanning_window_energy = kaldi::VecVec(
+        feature_window_function.window, feature_window_function.window);
     int32 sample_rate = opts.frame_opts.samp_freq;
     scale_ = 2.0 / (hanning_window_energy * sample_rate);
 }
 
 // Compute spectrogram feat
 bool LinearSpectrogramComputer::Compute(Vector<BaseFloat>* window,
-                                       Vector<BaseFloat>* feat) {
+                                        Vector<BaseFloat>* feat) {
     window->Resize(frame_length_, kaldi::kCopyData);
     RealFft(window, true);
     kaldi::ComputePowerSpectrum(window);
diff --git a/speechx/speechx/nnet/nnet_forward_main.cc b/speechx/speechx/nnet/nnet_forward_main.cc
index 0c5a55a71..0d4ea8ff7 100644
--- a/speechx/speechx/nnet/nnet_forward_main.cc
+++ b/speechx/speechx/nnet/nnet_forward_main.cc
@@ -14,8 +14,8 @@
 
 #include "base/flags.h"
 #include "base/log.h"
-#include "frontend/audio/data_cache.h"
 #include "frontend/audio/assembler.h"
+#include "frontend/audio/data_cache.h"
 #include "kaldi/util/table-types.h"
 #include "nnet/decodable.h"
 #include "nnet/paddle_nnet.h"
@@ -75,8 +75,8 @@ int main(int argc, char* argv[]) {
     std::shared_ptr<ppspeech::Decodable> decodable(
         new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
 
-    int32 chunk_size = FLAGS_receptive_field_length
-        + (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
+    int32 chunk_size = FLAGS_receptive_field_length +
+        (FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate;
     int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk;
     int32 receptive_field_length = FLAGS_receptive_field_length;
     LOG(INFO) << "chunk size (frame): " << chunk_size;
@@ -130,7 +130,9 @@ int main(int argc, char* argv[]) {
         vector<kaldi::BaseFloat> prob;
         while (decodable->FrameLikelihood(frame_idx, &prob)) {
             kaldi::Vector<kaldi::BaseFloat> vec_tmp(prob.size());
-            std::memcpy(vec_tmp.Data(), prob.data(), sizeof(kaldi::BaseFloat)*prob.size());
+            std::memcpy(vec_tmp.Data(),
+                        prob.data(),
+                        sizeof(kaldi::BaseFloat) * prob.size());
             prob_vec.push_back(vec_tmp);
             frame_idx++;
         }
@@ -142,7 +144,8 @@ int main(int argc, char* argv[]) {
             KALDI_LOG << " the nnet prob of " << utt << " is empty";
             continue;
         }
-        kaldi::Matrix<kaldi::BaseFloat> result(prob_vec.size(),prob_vec[0].Dim());
+        kaldi::Matrix<kaldi::BaseFloat> result(prob_vec.size(),
+                                               prob_vec[0].Dim());
         for (int32 row_idx = 0; row_idx < prob_vec.size(); ++row_idx) {
             for (int32 col_idx = 0; col_idx < prob_vec[0].Dim(); ++col_idx) {
                 result(row_idx, col_idx) = prob_vec[row_idx](col_idx);
diff --git a/speechx/speechx/protocol/websocket/websocket_client.h b/speechx/speechx/protocol/websocket/websocket_client.h
index 8635501a8..886da2929 100644
--- a/speechx/speechx/protocol/websocket/websocket_client.h
+++ b/speechx/speechx/protocol/websocket/websocket_client.h
@@ -40,8 +40,8 @@ class WebSocketClient {
     void SendEndSignal();
     void SendDataEnd();
     bool Done() const { return done_; }
-    std::string GetResult() const { return result_; } 
-    std::string GetPartialResult() const { return partial_result_;}
+    std::string GetResult() const { return result_; }
+    std::string GetPartialResult() const { return partial_result_; }
 
   private:
     void Connect();
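The websocket_client.h hunk above exposes a polling-style API: stream audio, signal the end of the utterance, then poll Done() while reading partial results. A hypothetical driver sketch; the namespace, the audio-feeding step, and the teardown ordering are assumptions, since only the method names appear in this patch:

    #include <chrono>
    #include <cstdint>
    #include <iostream>
    #include <thread>
    #include <vector>

    void Drive(ppspeech::WebSocketClient& client,
               const std::vector<int16_t>& pcm) {
        // ... feed pcm to the client in chunks (the exact send call for
        //     audio data is not shown in the hunk above) ...
        client.SendEndSignal();   // tell the server the utterance is over
        while (!client.Done()) {  // poll until decoding finishes
            std::cout << client.GetPartialResult() << "\r" << std::flush;
            std::this_thread::sleep_for(std::chrono::milliseconds(100));
        }
        std::cout << "\nfinal: " << client.GetResult() << std::endl;
        client.SendDataEnd();     // close the data channel (ordering assumed)
    }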
diff --git a/speechx/speechx/protocol/websocket/websocket_server.cc b/speechx/speechx/protocol/websocket/websocket_server.cc
index a1abd98e6..14f2f6e9f 100644
--- a/speechx/speechx/protocol/websocket/websocket_server.cc
+++ b/speechx/speechx/protocol/websocket/websocket_server.cc
@@ -76,9 +76,10 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) {
     recognizer_->Accept(pcm_data);
 
     std::string partial_result = recognizer_->GetPartialResult();
-    
-    json::value rv = {
-        {"status", "ok"}, {"type", "partial_result"}, {"result", partial_result}};
+
+    json::value rv = {{"status", "ok"},
+                      {"type", "partial_result"},
+                      {"result", partial_result}};
     ws_.text(true);
     ws_.write(asio::buffer(json::serialize(rv)));
 }
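For reference, the initializer-list construction kept by the websocket_server.cc hunk above builds a flat JSON object. A standalone sketch of the same construction and serialization with Boost.JSON (assuming the server's json alias refers to boost::json; the transcript string is invented):

    #include <boost/json.hpp>
    #include <iostream>
    #include <string>

    namespace json = boost::json;

    int main() {
        std::string partial_result = "hello world";  // invented transcript
        // Pairs of {string, value} initialize a JSON object, as above.
        json::value rv = {{"status", "ok"},
                          {"type", "partial_result"},
                          {"result", partial_result}};
        // Prints: {"status":"ok","type":"partial_result","result":"hello world"}
        std::cout << json::serialize(rv) << std::endl;
        return 0;
    }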