commit
61941d14b0
@ -0,0 +1,51 @@
|
|||||||
|
# This is the parameter configuration file for PaddleSpeech Serving.
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# SERVER SETTING #
|
||||||
|
#################################################################################
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 8091
|
||||||
|
|
||||||
|
# The task format in the engin_list is: <speech task>_<engine type>
|
||||||
|
# task choices = ['asr_online', 'tts_online']
|
||||||
|
# protocol = ['websocket', 'http'] (only one can be selected).
|
||||||
|
# websocket only support online engine type.
|
||||||
|
protocol: 'websocket'
|
||||||
|
engine_list: ['asr_online']
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# ENGINE CONFIG #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
################################### ASR #########################################
|
||||||
|
################### speech task: asr; engine_type: online #######################
|
||||||
|
asr_online:
|
||||||
|
model_type: 'deepspeech2online_aishell'
|
||||||
|
am_model: # the pdmodel file of am static model [optional]
|
||||||
|
am_params: # the pdiparams file of am static model [optional]
|
||||||
|
lang: 'zh'
|
||||||
|
sample_rate: 16000
|
||||||
|
cfg_path:
|
||||||
|
decode_method:
|
||||||
|
force_yes: True
|
||||||
|
|
||||||
|
am_predictor_conf:
|
||||||
|
device: # set 'gpu:id' or 'cpu'
|
||||||
|
switch_ir_optim: True
|
||||||
|
glog_info: False # True -> print glog
|
||||||
|
summary: True # False -> do not show predictor config
|
||||||
|
|
||||||
|
chunk_buffer_conf:
|
||||||
|
frame_duration_ms: 80
|
||||||
|
shift_ms: 40
|
||||||
|
sample_rate: 16000
|
||||||
|
sample_width: 2
|
||||||
|
|
||||||
|
vad_conf:
|
||||||
|
aggressiveness: 2
|
||||||
|
sample_rate: 16000
|
||||||
|
frame_duration_ms: 20
|
||||||
|
sample_width: 2
|
||||||
|
padding_ms: 200
|
||||||
|
padding_ratio: 0.9
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,355 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
from numpy import float32
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.s2t.frontend.speech import SpeechSegment
|
||||||
|
from paddlespeech.cli.asr.infer import ASRExecutor
|
||||||
|
from paddlespeech.cli.log import logger
|
||||||
|
from paddlespeech.cli.utils import MODEL_HOME
|
||||||
|
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||||
|
from paddlespeech.s2t.modules.ctc import CTCDecoder
|
||||||
|
from paddlespeech.s2t.utils.utility import UpdateConfig
|
||||||
|
from paddlespeech.server.engine.base_engine import BaseEngine
|
||||||
|
from paddlespeech.server.utils.config import get_config
|
||||||
|
from paddlespeech.server.utils.paddle_predictor import init_predictor
|
||||||
|
from paddlespeech.server.utils.paddle_predictor import run_model
|
||||||
|
|
||||||
|
__all__ = ['ASREngine']
|
||||||
|
|
||||||
|
pretrained_models = {
|
||||||
|
"deepspeech2online_aishell-zh-16k": {
|
||||||
|
'url':
|
||||||
|
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz',
|
||||||
|
'md5':
|
||||||
|
'd5e076217cf60486519f72c217d21b9b',
|
||||||
|
'cfg_path':
|
||||||
|
'model.yaml',
|
||||||
|
'ckpt_path':
|
||||||
|
'exp/deepspeech2_online/checkpoints/avg_1',
|
||||||
|
'model':
|
||||||
|
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel',
|
||||||
|
'params':
|
||||||
|
'exp/deepspeech2_online/checkpoints/avg_1.jit.pdiparams',
|
||||||
|
'lm_url':
|
||||||
|
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
|
||||||
|
'lm_md5':
|
||||||
|
'29e02312deb2e59b3c8686c7966d4fe3'
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ASRServerExecutor(ASRExecutor):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _init_from_path(self,
|
||||||
|
model_type: str='wenetspeech',
|
||||||
|
am_model: Optional[os.PathLike]=None,
|
||||||
|
am_params: Optional[os.PathLike]=None,
|
||||||
|
lang: str='zh',
|
||||||
|
sample_rate: int=16000,
|
||||||
|
cfg_path: Optional[os.PathLike]=None,
|
||||||
|
decode_method: str='attention_rescoring',
|
||||||
|
am_predictor_conf: dict=None):
|
||||||
|
"""
|
||||||
|
Init model and other resources from a specific path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if cfg_path is None or am_model is None or am_params is None:
|
||||||
|
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
|
||||||
|
tag = model_type + '-' + lang + '-' + sample_rate_str
|
||||||
|
res_path = self._get_pretrained_path(tag) # wenetspeech_zh
|
||||||
|
self.res_path = res_path
|
||||||
|
self.cfg_path = os.path.join(res_path,
|
||||||
|
pretrained_models[tag]['cfg_path'])
|
||||||
|
|
||||||
|
self.am_model = os.path.join(res_path,
|
||||||
|
pretrained_models[tag]['model'])
|
||||||
|
self.am_params = os.path.join(res_path,
|
||||||
|
pretrained_models[tag]['params'])
|
||||||
|
logger.info(res_path)
|
||||||
|
logger.info(self.cfg_path)
|
||||||
|
logger.info(self.am_model)
|
||||||
|
logger.info(self.am_params)
|
||||||
|
else:
|
||||||
|
self.cfg_path = os.path.abspath(cfg_path)
|
||||||
|
self.am_model = os.path.abspath(am_model)
|
||||||
|
self.am_params = os.path.abspath(am_params)
|
||||||
|
self.res_path = os.path.dirname(
|
||||||
|
os.path.dirname(os.path.abspath(self.cfg_path)))
|
||||||
|
|
||||||
|
#Init body.
|
||||||
|
self.config = CfgNode(new_allowed=True)
|
||||||
|
self.config.merge_from_file(self.cfg_path)
|
||||||
|
|
||||||
|
with UpdateConfig(self.config):
|
||||||
|
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
|
||||||
|
from paddlespeech.s2t.io.collator import SpeechCollator
|
||||||
|
self.vocab = self.config.vocab_filepath
|
||||||
|
self.config.decode.lang_model_path = os.path.join(
|
||||||
|
MODEL_HOME, 'language_model',
|
||||||
|
self.config.decode.lang_model_path)
|
||||||
|
self.collate_fn_test = SpeechCollator.from_config(self.config)
|
||||||
|
self.text_feature = TextFeaturizer(
|
||||||
|
unit_type=self.config.unit_type, vocab=self.vocab)
|
||||||
|
|
||||||
|
lm_url = pretrained_models[tag]['lm_url']
|
||||||
|
lm_md5 = pretrained_models[tag]['lm_md5']
|
||||||
|
self.download_lm(
|
||||||
|
lm_url,
|
||||||
|
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
|
||||||
|
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
|
||||||
|
raise Exception("wrong type")
|
||||||
|
else:
|
||||||
|
raise Exception("wrong type")
|
||||||
|
|
||||||
|
# AM predictor
|
||||||
|
self.am_predictor_conf = am_predictor_conf
|
||||||
|
self.am_predictor = init_predictor(
|
||||||
|
model_file=self.am_model,
|
||||||
|
params_file=self.am_params,
|
||||||
|
predictor_conf=self.am_predictor_conf)
|
||||||
|
|
||||||
|
# decoder
|
||||||
|
self.decoder = CTCDecoder(
|
||||||
|
odim=self.config.output_dim, # <blank> is in vocab
|
||||||
|
enc_n_units=self.config.rnn_layer_size * 2,
|
||||||
|
blank_id=self.config.blank_id,
|
||||||
|
dropout_rate=0.0,
|
||||||
|
reduction=True, # sum
|
||||||
|
batch_average=True, # sum / batch_size
|
||||||
|
grad_norm_type=self.config.get('ctc_grad_norm_type', None))
|
||||||
|
|
||||||
|
# init decoder
|
||||||
|
cfg = self.config.decode
|
||||||
|
decode_batch_size = 1 # for online
|
||||||
|
self.decoder.init_decoder(
|
||||||
|
decode_batch_size, self.text_feature.vocab_list,
|
||||||
|
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
|
||||||
|
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
|
||||||
|
cfg.num_proc_bsearch)
|
||||||
|
|
||||||
|
# init state box
|
||||||
|
self.chunk_state_h_box = np.zeros(
|
||||||
|
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
|
||||||
|
dtype=float32)
|
||||||
|
self.chunk_state_c_box = np.zeros(
|
||||||
|
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
|
||||||
|
dtype=float32)
|
||||||
|
|
||||||
|
def reset_decoder_and_chunk(self):
|
||||||
|
"""reset decoder and chunk state for an new audio
|
||||||
|
"""
|
||||||
|
self.decoder.reset_decoder(batch_size=1)
|
||||||
|
# init state box, for new audio request
|
||||||
|
self.chunk_state_h_box = np.zeros(
|
||||||
|
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
|
||||||
|
dtype=float32)
|
||||||
|
self.chunk_state_c_box = np.zeros(
|
||||||
|
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
|
||||||
|
dtype=float32)
|
||||||
|
|
||||||
|
def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
|
||||||
|
"""decode one chunk
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_chunk (numpy.array): shape[B, T, D]
|
||||||
|
x_chunk_lens (numpy.array): shape[B]
|
||||||
|
model_type (str): online model type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[type]: [description]
|
||||||
|
"""
|
||||||
|
if "deepspeech2online" in model_type :
|
||||||
|
input_names = self.am_predictor.get_input_names()
|
||||||
|
audio_handle = self.am_predictor.get_input_handle(input_names[0])
|
||||||
|
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
|
||||||
|
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
|
||||||
|
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
|
||||||
|
|
||||||
|
audio_handle.reshape(x_chunk.shape)
|
||||||
|
audio_handle.copy_from_cpu(x_chunk)
|
||||||
|
|
||||||
|
audio_len_handle.reshape(x_chunk_lens.shape)
|
||||||
|
audio_len_handle.copy_from_cpu(x_chunk_lens)
|
||||||
|
|
||||||
|
h_box_handle.reshape(self.chunk_state_h_box.shape)
|
||||||
|
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
|
||||||
|
|
||||||
|
c_box_handle.reshape(self.chunk_state_c_box.shape)
|
||||||
|
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
|
||||||
|
|
||||||
|
output_names = self.am_predictor.get_output_names()
|
||||||
|
output_handle = self.am_predictor.get_output_handle(output_names[0])
|
||||||
|
output_lens_handle = self.am_predictor.get_output_handle(output_names[1])
|
||||||
|
output_state_h_handle = self.am_predictor.get_output_handle(
|
||||||
|
output_names[2])
|
||||||
|
output_state_c_handle = self.am_predictor.get_output_handle(
|
||||||
|
output_names[3])
|
||||||
|
|
||||||
|
self.am_predictor.run()
|
||||||
|
|
||||||
|
output_chunk_probs = output_handle.copy_to_cpu()
|
||||||
|
output_chunk_lens = output_lens_handle.copy_to_cpu()
|
||||||
|
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
|
||||||
|
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
|
||||||
|
|
||||||
|
self.decoder.next(output_chunk_probs, output_chunk_lens)
|
||||||
|
trans_best, trans_beam = self.decoder.decode()
|
||||||
|
|
||||||
|
return trans_best[0]
|
||||||
|
|
||||||
|
elif "conformer" in model_type or "transformer" in model_type:
|
||||||
|
raise Exception("invalid model name")
|
||||||
|
else:
|
||||||
|
raise Exception("invalid model name")
|
||||||
|
|
||||||
|
def _pcm16to32(self, audio):
|
||||||
|
"""pcm int16 to float32
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio(numpy.array): numpy.int16
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
audio(numpy.array): numpy.float32
|
||||||
|
"""
|
||||||
|
if audio.dtype == np.int16:
|
||||||
|
audio = audio.astype("float32")
|
||||||
|
bits = np.iinfo(np.int16).bits
|
||||||
|
audio = audio / (2**(bits - 1))
|
||||||
|
return audio
|
||||||
|
|
||||||
|
def extract_feat(self, samples, sample_rate):
|
||||||
|
"""extract feat
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples (numpy.array): numpy.float32
|
||||||
|
sample_rate (int): sample rate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x_chunk (numpy.array): shape[B, T, D]
|
||||||
|
x_chunk_lens (numpy.array): shape[B]
|
||||||
|
"""
|
||||||
|
# pcm16 -> pcm 32
|
||||||
|
samples = self._pcm16to32(samples)
|
||||||
|
|
||||||
|
# read audio
|
||||||
|
speech_segment = SpeechSegment.from_pcm(
|
||||||
|
samples, sample_rate, transcript=" ")
|
||||||
|
# audio augment
|
||||||
|
self.collate_fn_test.augmentation.transform_audio(speech_segment)
|
||||||
|
|
||||||
|
# extract speech feature
|
||||||
|
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
|
||||||
|
speech_segment, self.collate_fn_test.keep_transcription_text)
|
||||||
|
# CMVN spectrum
|
||||||
|
if self.collate_fn_test._normalizer:
|
||||||
|
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
|
||||||
|
|
||||||
|
# spectrum augment
|
||||||
|
audio = self.collate_fn_test.augmentation.transform_feature(spectrum)
|
||||||
|
|
||||||
|
audio_len = audio.shape[0]
|
||||||
|
audio = paddle.to_tensor(audio, dtype='float32')
|
||||||
|
# audio_len = paddle.to_tensor(audio_len)
|
||||||
|
audio = paddle.unsqueeze(audio, axis=0)
|
||||||
|
|
||||||
|
x_chunk = audio.numpy()
|
||||||
|
x_chunk_lens = np.array([audio_len])
|
||||||
|
|
||||||
|
return x_chunk, x_chunk_lens
|
||||||
|
|
||||||
|
|
||||||
|
class ASREngine(BaseEngine):
|
||||||
|
"""ASR server engine
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metaclass: Defaults to Singleton.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(ASREngine, self).__init__()
|
||||||
|
|
||||||
|
def init(self, config: dict) -> bool:
|
||||||
|
"""init engine resource
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_file (str): config file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: init failed or success
|
||||||
|
"""
|
||||||
|
self.input = None
|
||||||
|
self.output = ""
|
||||||
|
self.executor = ASRServerExecutor()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.executor._init_from_path(
|
||||||
|
model_type=self.config.model_type,
|
||||||
|
am_model=self.config.am_model,
|
||||||
|
am_params=self.config.am_params,
|
||||||
|
lang=self.config.lang,
|
||||||
|
sample_rate=self.config.sample_rate,
|
||||||
|
cfg_path=self.config.cfg_path,
|
||||||
|
decode_method=self.config.decode_method,
|
||||||
|
am_predictor_conf=self.config.am_predictor_conf)
|
||||||
|
|
||||||
|
logger.info("Initialize ASR server engine successfully.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def preprocess(self, samples, sample_rate):
|
||||||
|
"""preprocess
|
||||||
|
|
||||||
|
Args:
|
||||||
|
samples (numpy.array): numpy.float32
|
||||||
|
sample_rate (int): sample rate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x_chunk (numpy.array): shape[B, T, D]
|
||||||
|
x_chunk_lens (numpy.array): shape[B]
|
||||||
|
"""
|
||||||
|
x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
|
||||||
|
return x_chunk, x_chunk_lens
|
||||||
|
|
||||||
|
def run(self, x_chunk, x_chunk_lens, decoder_chunk_size=1):
|
||||||
|
"""run online engine
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_chunk (numpy.array): shape[B, T, D]
|
||||||
|
x_chunk_lens (numpy.array): shape[B]
|
||||||
|
decoder_chunk_size(int)
|
||||||
|
"""
|
||||||
|
self.output = self.executor.decode_one_chunk(x_chunk, x_chunk_lens, self.config.model_type)
|
||||||
|
|
||||||
|
def postprocess(self):
|
||||||
|
"""postprocess
|
||||||
|
"""
|
||||||
|
return self.output
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""reset engine decoder and inference state
|
||||||
|
"""
|
||||||
|
self.executor.reset_decoder_and_chunk()
|
||||||
|
self.output = ""
|
@ -0,0 +1,161 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
record wave from the mic
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import wave
|
||||||
|
from signal import SIGINT
|
||||||
|
from signal import SIGTERM
|
||||||
|
|
||||||
|
import pyaudio
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
|
||||||
|
class ASRAudioHandler(threading.Thread):
|
||||||
|
def __init__(self, url="127.0.0.1", port=8091):
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.url = url
|
||||||
|
self.port = port
|
||||||
|
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
|
||||||
|
self.fileName = "./output.wav"
|
||||||
|
self.chunk = 5120
|
||||||
|
self.format = pyaudio.paInt16
|
||||||
|
self.channels = 1
|
||||||
|
self.rate = 16000
|
||||||
|
self._running = True
|
||||||
|
self._frames = []
|
||||||
|
self.data_backup = []
|
||||||
|
|
||||||
|
def startrecord(self):
|
||||||
|
"""
|
||||||
|
start a new thread to record wave
|
||||||
|
"""
|
||||||
|
threading._start_new_thread(self.recording, ())
|
||||||
|
|
||||||
|
def recording(self):
|
||||||
|
"""
|
||||||
|
recording wave
|
||||||
|
"""
|
||||||
|
self._running = True
|
||||||
|
self._frames = []
|
||||||
|
p = pyaudio.PyAudio()
|
||||||
|
stream = p.open(
|
||||||
|
format=self.format,
|
||||||
|
channels=self.channels,
|
||||||
|
rate=self.rate,
|
||||||
|
input=True,
|
||||||
|
frames_per_buffer=self.chunk)
|
||||||
|
while (self._running):
|
||||||
|
data = stream.read(self.chunk)
|
||||||
|
self._frames.append(data)
|
||||||
|
self.data_backup.append(data)
|
||||||
|
|
||||||
|
stream.stop_stream()
|
||||||
|
stream.close()
|
||||||
|
p.terminate()
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""
|
||||||
|
save wave data
|
||||||
|
"""
|
||||||
|
p = pyaudio.PyAudio()
|
||||||
|
wf = wave.open(self.fileName, 'wb')
|
||||||
|
wf.setnchannels(self.channels)
|
||||||
|
wf.setsampwidth(p.get_sample_size(self.format))
|
||||||
|
wf.setframerate(self.rate)
|
||||||
|
wf.writeframes(b''.join(self.data_backup))
|
||||||
|
wf.close()
|
||||||
|
p.terminate()
|
||||||
|
|
||||||
|
def stoprecord(self):
|
||||||
|
"""
|
||||||
|
stop recording
|
||||||
|
"""
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
aa = input("是否开始录音? (y/n)")
|
||||||
|
if aa.strip() == "y":
|
||||||
|
self.startrecord()
|
||||||
|
logging.info("*" * 10 + "开始录音,请输入语音")
|
||||||
|
|
||||||
|
async with websockets.connect(self.url) as ws:
|
||||||
|
# 发送开始指令
|
||||||
|
audio_info = json.dumps(
|
||||||
|
{
|
||||||
|
"name": "test.wav",
|
||||||
|
"signal": "start",
|
||||||
|
"nbest": 5
|
||||||
|
},
|
||||||
|
sort_keys=True,
|
||||||
|
indent=4,
|
||||||
|
separators=(',', ': '))
|
||||||
|
await ws.send(audio_info)
|
||||||
|
msg = await ws.recv()
|
||||||
|
logging.info("receive msg={}".format(msg))
|
||||||
|
|
||||||
|
# send bytes data
|
||||||
|
logging.info("结束录音请: Ctrl + c。继续请按回车。")
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
while len(self._frames) > 0:
|
||||||
|
await ws.send(self._frames.pop(0))
|
||||||
|
msg = await ws.recv()
|
||||||
|
logging.info("receive msg={}".format(msg))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# quit
|
||||||
|
# send finished
|
||||||
|
audio_info = json.dumps(
|
||||||
|
{
|
||||||
|
"name": "test.wav",
|
||||||
|
"signal": "end",
|
||||||
|
"nbest": 5
|
||||||
|
},
|
||||||
|
sort_keys=True,
|
||||||
|
indent=4,
|
||||||
|
separators=(',', ': '))
|
||||||
|
await ws.send(audio_info)
|
||||||
|
msg = await ws.recv()
|
||||||
|
logging.info("receive msg={}".format(msg))
|
||||||
|
|
||||||
|
self.stoprecord()
|
||||||
|
logging.info("*" * 10 + "录音结束")
|
||||||
|
self.save()
|
||||||
|
elif aa.strip() == "n":
|
||||||
|
exit()
|
||||||
|
else:
|
||||||
|
print("无效输入!")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logging.info("asr websocket client start")
|
||||||
|
|
||||||
|
handler = ASRAudioHandler("127.0.0.1", 8091)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
main_task = asyncio.ensure_future(handler.run())
|
||||||
|
for signal in [SIGINT, SIGTERM]:
|
||||||
|
loop.add_signal_handler(signal, main_task.cancel)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(main_task)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
logging.info("asr websocket client finished")
|
@ -0,0 +1,115 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#!/usr/bin/python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
|
||||||
|
class ASRAudioHandler:
|
||||||
|
def __init__(self, url="127.0.0.1", port=8090):
|
||||||
|
self.url = url
|
||||||
|
self.port = port
|
||||||
|
self.url = "ws://" + self.url + ":" + str(self.port) + "/ws/asr"
|
||||||
|
|
||||||
|
def read_wave(self, wavfile_path: str):
|
||||||
|
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
|
||||||
|
x_len = len(samples)
|
||||||
|
chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
|
||||||
|
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
|
||||||
|
|
||||||
|
if (x_len - chunk_size) % chunk_stride != 0:
|
||||||
|
padding_len_x = chunk_stride - (x_len - chunk_size) % chunk_stride
|
||||||
|
else:
|
||||||
|
padding_len_x = 0
|
||||||
|
|
||||||
|
padding = np.zeros((padding_len_x), dtype=samples.dtype)
|
||||||
|
padded_x = np.concatenate([samples, padding], axis=0)
|
||||||
|
|
||||||
|
num_chunk = (x_len + padding_len_x - chunk_size) / chunk_stride + 1
|
||||||
|
num_chunk = int(num_chunk)
|
||||||
|
|
||||||
|
for i in range(0, num_chunk):
|
||||||
|
start = i * chunk_stride
|
||||||
|
end = start + chunk_size
|
||||||
|
x_chunk = padded_x[start:end]
|
||||||
|
yield x_chunk
|
||||||
|
|
||||||
|
async def run(self, wavfile_path: str):
|
||||||
|
logging.info("send a message to the server")
|
||||||
|
# 读取音频
|
||||||
|
# self.read_wave()
|
||||||
|
# 发送 websocket 的 handshake 协议头
|
||||||
|
async with websockets.connect(self.url) as ws:
|
||||||
|
# server 端已经接收到 handshake 协议头
|
||||||
|
# 发送开始指令
|
||||||
|
audio_info = json.dumps(
|
||||||
|
{
|
||||||
|
"name": "test.wav",
|
||||||
|
"signal": "start",
|
||||||
|
"nbest": 5
|
||||||
|
},
|
||||||
|
sort_keys=True,
|
||||||
|
indent=4,
|
||||||
|
separators=(',', ': '))
|
||||||
|
await ws.send(audio_info)
|
||||||
|
msg = await ws.recv()
|
||||||
|
logging.info("receive msg={}".format(msg))
|
||||||
|
|
||||||
|
# send chunk audio data to engine
|
||||||
|
for chunk_data in self.read_wave(wavfile_path):
|
||||||
|
await ws.send(chunk_data.tobytes())
|
||||||
|
msg = await ws.recv()
|
||||||
|
logging.info("receive msg={}".format(msg))
|
||||||
|
|
||||||
|
# finished
|
||||||
|
audio_info = json.dumps(
|
||||||
|
{
|
||||||
|
"name": "test.wav",
|
||||||
|
"signal": "end",
|
||||||
|
"nbest": 5
|
||||||
|
},
|
||||||
|
sort_keys=True,
|
||||||
|
indent=4,
|
||||||
|
separators=(',', ': '))
|
||||||
|
await ws.send(audio_info)
|
||||||
|
msg = await ws.recv()
|
||||||
|
logging.info("receive msg={}".format(msg))
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logging.info("asr websocket client start")
|
||||||
|
handler = ASRAudioHandler("127.0.0.1", 8091)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(handler.run(args.wavfile))
|
||||||
|
logging.info("asr websocket client finished")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--wavfile",
|
||||||
|
action="store",
|
||||||
|
help="wav file path ",
|
||||||
|
default="./16_audio.wav")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
@ -0,0 +1,59 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
class Frame(object):
|
||||||
|
"""Represents a "frame" of audio data."""
|
||||||
|
|
||||||
|
def __init__(self, bytes, timestamp, duration):
|
||||||
|
self.bytes = bytes
|
||||||
|
self.timestamp = timestamp
|
||||||
|
self.duration = duration
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkBuffer(object):
|
||||||
|
def __init__(self,
|
||||||
|
frame_duration_ms=80,
|
||||||
|
shift_ms=40,
|
||||||
|
sample_rate=16000,
|
||||||
|
sample_width=2):
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.frame_duration_ms = frame_duration_ms
|
||||||
|
self.shift_ms = shift_ms
|
||||||
|
self.remained_audio = b''
|
||||||
|
self.sample_width = sample_width # int16 = 2; float32 = 4
|
||||||
|
|
||||||
|
def frame_generator(self, audio):
|
||||||
|
"""Generates audio frames from PCM audio data.
|
||||||
|
Takes the desired frame duration in milliseconds, the PCM data, and
|
||||||
|
the sample rate.
|
||||||
|
Yields Frames of the requested duration.
|
||||||
|
"""
|
||||||
|
audio = self.remained_audio + audio
|
||||||
|
self.remained_audio = b''
|
||||||
|
|
||||||
|
n = int(self.sample_rate *
|
||||||
|
(self.frame_duration_ms / 1000.0) * self.sample_width)
|
||||||
|
shift_n = int(self.sample_rate *
|
||||||
|
(self.shift_ms / 1000.0) * self.sample_width)
|
||||||
|
offset = 0
|
||||||
|
timestamp = 0.0
|
||||||
|
duration = (float(n) / self.sample_rate) / self.sample_width
|
||||||
|
shift_duration = (float(shift_n) / self.sample_rate) / self.sample_width
|
||||||
|
while offset + n <= len(audio):
|
||||||
|
yield Frame(audio[offset:offset + n], timestamp, duration)
|
||||||
|
timestamp += shift_duration
|
||||||
|
offset += shift_n
|
||||||
|
|
||||||
|
self.remained_audio += audio[offset:]
|
@ -0,0 +1,78 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import webrtcvad
|
||||||
|
|
||||||
|
|
||||||
|
class VADAudio():
|
||||||
|
def __init__(self,
|
||||||
|
aggressiveness=2,
|
||||||
|
rate=16000,
|
||||||
|
frame_duration_ms=20,
|
||||||
|
sample_width=2,
|
||||||
|
padding_ms=200,
|
||||||
|
padding_ratio=0.9):
|
||||||
|
"""Initializes VAD with given aggressivenes and sets up internal queues"""
|
||||||
|
self.vad = webrtcvad.Vad(aggressiveness)
|
||||||
|
self.rate = rate
|
||||||
|
self.sample_width = sample_width
|
||||||
|
self.frame_duration_ms = frame_duration_ms
|
||||||
|
self._frame_length = int(rate * (frame_duration_ms / 1000.0) *
|
||||||
|
self.sample_width)
|
||||||
|
self._buffer_queue = collections.deque()
|
||||||
|
self.ring_buffer = collections.deque(maxlen=padding_ms //
|
||||||
|
frame_duration_ms)
|
||||||
|
self._ratio = padding_ratio
|
||||||
|
self.triggered = False
|
||||||
|
|
||||||
|
def add_audio(self, audio):
|
||||||
|
"""Adds new audio to internal queue"""
|
||||||
|
for x in audio:
|
||||||
|
self._buffer_queue.append(x)
|
||||||
|
|
||||||
|
def frame_generator(self):
|
||||||
|
"""Generator that yields audio frames of frame_duration_ms"""
|
||||||
|
while len(self._buffer_queue) > self._frame_length:
|
||||||
|
frame = bytearray()
|
||||||
|
for _ in range(self._frame_length):
|
||||||
|
frame.append(self._buffer_queue.popleft())
|
||||||
|
yield bytes(frame)
|
||||||
|
|
||||||
|
def vad_collector(self):
|
||||||
|
"""Generator that yields series of consecutive audio frames comprising each utterence, separated by yielding a single None.
|
||||||
|
Determines voice activity by ratio of frames in padding_ms. Uses a buffer to include padding_ms prior to being triggered.
|
||||||
|
Example: (frame, ..., frame, None, frame, ..., frame, None, ...)
|
||||||
|
|---utterence---| |---utterence---|
|
||||||
|
"""
|
||||||
|
for frame in self.frame_generator():
|
||||||
|
is_speech = self.vad.is_speech(frame, self.rate)
|
||||||
|
if not self.triggered:
|
||||||
|
self.ring_buffer.append((frame, is_speech))
|
||||||
|
num_voiced = len(
|
||||||
|
[f for f, speech in self.ring_buffer if speech])
|
||||||
|
if num_voiced > self._ratio * self.ring_buffer.maxlen:
|
||||||
|
self.triggered = True
|
||||||
|
for f, s in self.ring_buffer:
|
||||||
|
yield f
|
||||||
|
self.ring_buffer.clear()
|
||||||
|
else:
|
||||||
|
yield frame
|
||||||
|
self.ring_buffer.append((frame, is_speech))
|
||||||
|
num_unvoiced = len(
|
||||||
|
[f for f, speech in self.ring_buffer if not speech])
|
||||||
|
if num_unvoiced > self._ratio * self.ring_buffer.maxlen:
|
||||||
|
self.triggered = False
|
||||||
|
yield None
|
||||||
|
self.ring_buffer.clear()
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from paddlespeech.server.ws.asr_socket import router as asr_router
|
||||||
|
|
||||||
|
_router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def setup_router(api_list: List):
|
||||||
|
"""setup router for fastapi
|
||||||
|
Args:
|
||||||
|
api_list (List): [asr, tts]
|
||||||
|
Returns:
|
||||||
|
APIRouter
|
||||||
|
"""
|
||||||
|
for api_name in api_list:
|
||||||
|
if api_name == 'asr':
|
||||||
|
_router.include_router(asr_router)
|
||||||
|
elif api_name == 'tts':
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return _router
|
@ -0,0 +1,100 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from fastapi import WebSocket
|
||||||
|
from fastapi import WebSocketDisconnect
|
||||||
|
from starlette.websockets import WebSocketState as WebSocketState
|
||||||
|
|
||||||
|
from paddlespeech.server.engine.engine_pool import get_engine_pool
|
||||||
|
from paddlespeech.server.utils.buffer import ChunkBuffer
|
||||||
|
from paddlespeech.server.utils.vad import VADAudio
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket('/ws/asr')
|
||||||
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
engine_pool = get_engine_pool()
|
||||||
|
asr_engine = engine_pool['asr']
|
||||||
|
# init buffer
|
||||||
|
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
|
||||||
|
chunk_buffer = ChunkBuffer(
|
||||||
|
sample_rate=chunk_buffer_conf['sample_rate'],
|
||||||
|
sample_width=chunk_buffer_conf['sample_width'])
|
||||||
|
# init vad
|
||||||
|
vad_conf = asr_engine.config.vad_conf
|
||||||
|
vad = VADAudio(
|
||||||
|
aggressiveness=vad_conf['aggressiveness'],
|
||||||
|
rate=vad_conf['sample_rate'],
|
||||||
|
frame_duration_ms=vad_conf['frame_duration_ms'])
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# careful here, changed the source code from starlette.websockets
|
||||||
|
assert websocket.application_state == WebSocketState.CONNECTED
|
||||||
|
message = await websocket.receive()
|
||||||
|
websocket._raise_on_disconnect(message)
|
||||||
|
if "text" in message:
|
||||||
|
message = json.loads(message["text"])
|
||||||
|
if 'signal' not in message:
|
||||||
|
resp = {"status": "ok", "message": "no valid json data"}
|
||||||
|
await websocket.send_json(resp)
|
||||||
|
|
||||||
|
if message['signal'] == 'start':
|
||||||
|
resp = {"status": "ok", "signal": "server_ready"}
|
||||||
|
# do something at begining here
|
||||||
|
await websocket.send_json(resp)
|
||||||
|
elif message['signal'] == 'end':
|
||||||
|
engine_pool = get_engine_pool()
|
||||||
|
asr_engine = engine_pool['asr']
|
||||||
|
# reset single engine for an new connection
|
||||||
|
asr_engine.reset()
|
||||||
|
resp = {"status": "ok", "signal": "finished"}
|
||||||
|
await websocket.send_json(resp)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
resp = {"status": "ok", "message": "no valid json data"}
|
||||||
|
await websocket.send_json(resp)
|
||||||
|
elif "bytes" in message:
|
||||||
|
message = message["bytes"]
|
||||||
|
|
||||||
|
# vad for input bytes audio
|
||||||
|
vad.add_audio(message)
|
||||||
|
message = b''.join(f for f in vad.vad_collector()
|
||||||
|
if f is not None)
|
||||||
|
|
||||||
|
engine_pool = get_engine_pool()
|
||||||
|
asr_engine = engine_pool['asr']
|
||||||
|
asr_results = ""
|
||||||
|
frames = chunk_buffer.frame_generator(message)
|
||||||
|
for frame in frames:
|
||||||
|
samples = np.frombuffer(frame.bytes, dtype=np.int16)
|
||||||
|
sample_rate = asr_engine.config.sample_rate
|
||||||
|
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
|
||||||
|
sample_rate)
|
||||||
|
asr_engine.run(x_chunk, x_chunk_lens)
|
||||||
|
asr_results = asr_engine.postprocess()
|
||||||
|
|
||||||
|
asr_results = asr_engine.postprocess()
|
||||||
|
resp = {'asr_results': asr_results}
|
||||||
|
|
||||||
|
await websocket.send_json(resp)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
Loading…
Reference in new issue