commit
43582f5091
@ -0,0 +1,35 @@
|
|||||||
|
# This is the parameter configuration file for PaddleSpeech Serving.
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# SERVER SETTING #
|
||||||
|
#################################################################################
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 8190
|
||||||
|
|
||||||
|
# The task format in the engine_list is: <speech task>_<engine type>
|
||||||
|
# task choices = ['text_python']
|
||||||
|
# protocol = ['http'] (only one can be selected).
|
||||||
|
# http only supports the offline engine type.
|
||||||
|
protocol: 'http'
|
||||||
|
engine_list: ['text_python']
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# ENGINE CONFIG #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
################################### Text #########################################
|
||||||
|
################### text task: punc; engine_type: python #######################
|
||||||
|
text_python:
|
||||||
|
task: punc
|
||||||
|
model_type: 'ernie_linear_p3_wudao'
|
||||||
|
lang: 'zh'
|
||||||
|
sample_rate: 16000
|
||||||
|
cfg_path: # [optional]
|
||||||
|
ckpt_path: # [optional]
|
||||||
|
vocab_file: # [optional]
|
||||||
|
device: 'cpu' # set 'gpu:id' or 'cpu'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from paddlespeech.cli.log import logger
|
||||||
|
from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
|
||||||
|
if __name__ == "__main__":
    # Command-line interface: a mandatory server config and an optional
    # log file location.
    cli = argparse.ArgumentParser(
        prog='paddlespeech_server.start', add_help=True)
    cli.add_argument(
        "--config_file",
        required=True,
        default=None,
        action="store",
        help="yaml file of the app")
    cli.add_argument(
        "--log_file",
        default="./log/paddlespeech.log",
        action="store",
        help="log file")

    logger.info("start to parse the args")
    options = cli.parse_args()

    # Hand the parsed options to the server executor, which runs the service.
    logger.info("start to launch the punctuation server")
    launch_server = ServerExecutor()
    launch_server(config_file=options.config_file, log_file=options.log_file)
|
@ -0,0 +1,5 @@
|
|||||||
|
# Restrict both servers to GPUs 0-3. The variable CUDA reads is
# CUDA_VISIBLE_DEVICES (plural); the singular spelling is silently ignored.
export CUDA_VISIBLE_DEVICES=0,1,2,3

# Start the punctuation server in the background; stdout/stderr go to punc.log.
nohup python3 punc_server.py --config_file conf/punc_application.yaml > punc.log 2>&1 &

# Start the streaming ASR server in the background; stdout/stderr go to streaming_asr.log.
nohup python3 streaming_asr_server.py --config_file conf/ws_conformer_application.yaml > streaming_asr.log 2>&1 &
|
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from paddlespeech.cli.log import logger
|
||||||
|
from paddlespeech.server.bin.paddlespeech_server import ServerExecutor
|
||||||
|
if __name__ == "__main__":
    # Command-line interface: a mandatory server config and an optional
    # log file location.
    cli = argparse.ArgumentParser(
        prog='paddlespeech_server.start', add_help=True)
    cli.add_argument(
        "--config_file",
        required=True,
        default=None,
        action="store",
        help="yaml file of the app")
    cli.add_argument(
        "--log_file",
        default="./log/paddlespeech.log",
        action="store",
        help="log file")

    logger.info("start to parse the args")
    options = cli.parse_args()

    # Hand the parsed options to the server executor, which runs the service.
    logger.info("start to launch the streaming asr server")
    launch_server = ServerExecutor()
    launch_server(config_file=options.config_file, log_file=options.log_file)
|
@ -1,5 +1,8 @@
|
|||||||
# download the test wav
|
# download the test wav
|
||||||
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
|
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
|
||||||
|
|
||||||
# read the wav and pass it to service
|
# read the wav and pass it only to the streaming asr service
|
||||||
python3 websocket_client.py --wavfile ./zh.wav
|
python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav
|
||||||
|
|
||||||
|
# read the wav and call both the streaming asr and the punc services
|
||||||
|
python3 websocket_client.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav
|
||||||
|
@ -0,0 +1,32 @@
|
|||||||
|
# This is the parameter configuration file for PaddleSpeech Serving.
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# SERVER SETTING #
|
||||||
|
#################################################################################
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 8090
|
||||||
|
|
||||||
|
# The task format in the engine_list is: <speech task>_<engine type>
|
||||||
|
# protocol = ['http'] (only one can be selected).
|
||||||
|
# http only supports the offline engine type.
|
||||||
|
protocol: 'http'
|
||||||
|
engine_list: ['vector_python']
|
||||||
|
|
||||||
|
|
||||||
|
#################################################################################
|
||||||
|
# ENGINE CONFIG #
|
||||||
|
#################################################################################
|
||||||
|
|
||||||
|
################################### Vector ######################################
|
||||||
|
################### Vector task: spk; engine_type: python #######################
|
||||||
|
vector_python:
|
||||||
|
task: spk
|
||||||
|
model_type: 'ecapatdnn_voxceleb12'
|
||||||
|
sample_rate: 16000
|
||||||
|
cfg_path: # [optional]
|
||||||
|
ckpt_path: # [optional]
|
||||||
|
device: # set 'gpu:id' or 'cpu'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,200 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import io
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from paddleaudio.backends import load as load_audio
|
||||||
|
from paddleaudio.compliance.librosa import melspectrogram
|
||||||
|
from paddlespeech.cli.log import logger
|
||||||
|
from paddlespeech.cli.vector.infer import VectorExecutor
|
||||||
|
from paddlespeech.server.engine.base_engine import BaseEngine
|
||||||
|
from paddlespeech.vector.io.batch import feature_normalize
|
||||||
|
|
||||||
|
|
||||||
|
class PaddleVectorConnectionHandler:
    def __init__(self, vector_engine):
        """The PaddleSpeech Vector Server Connection Handler
           This connection process every server request

        Args:
            vector_engine (VectorEngine): The Vector engine
        """
        super().__init__()
        logger.info(
            "Create PaddleVectorConnectionHandler to process the vector request")
        self.vector_engine = vector_engine
        # Cache the executor and the attributes read on every request.
        self.executor = self.vector_engine.executor
        self.task = self.vector_engine.executor.task
        self.model = self.vector_engine.executor.model
        self.config = self.vector_engine.executor.config

        self._inputs = OrderedDict()
        self._outputs = OrderedDict()

    @paddle.no_grad()
    def run(self, audio_data, task="spk"):
        """Process the audio of one http request.

        Args:
            audio_data (bytes): the audio bytes, already base64-decoded
            task (str, optional): the task asked for by the request.
                Defaults to "spk".

        Returns:
            numpy.ndarray: the audio embedding when both the server task and
                the request task are "spk"; otherwise the placeholder
                ``np.array([0.0])`` after the mismatch has been logged.
        """
        logger.info(
            f"start to extract the do vector {self.task} from the http request")
        if self.task == "spk" and task == "spk":
            embedding = self.extract_audio_embedding(audio_data)
            return embedding
        else:
            logger.error(
                "The request task is not matched with server model task")
            logger.error(
                f"The server model task is: {self.task}, but the request task is: {task}"
            )

        return np.array([
            0.0,
        ])

    @paddle.no_grad()
    def get_enroll_test_score(self, enroll_audio, test_audio):
        """Get the enroll and test audio score

        Args:
            enroll_audio (bytes): the enroll audio bytes
            test_audio (bytes): the test audio bytes

        Returns:
            the score between the enroll and test audio, as returned by the
            executor's ``get_embeddings_score``
        """
        logger.info("start to extract the enroll audio embedding")
        enroll_emb = self.extract_audio_embedding(enroll_audio)

        logger.info("start to extract the test audio embedding")
        test_emb = self.extract_audio_embedding(test_audio)

        logger.info(
            "start to get the score between the enroll and test embedding")
        score = self.executor.get_embeddings_score(enroll_emb, test_emb)

        logger.info(f"get the enroll vs test score: {score}")
        return score

    @paddle.no_grad()
    def extract_audio_embedding(self, audio: bytes, sample_rate: int=16000):
        """extract the audio embedding

        Args:
            audio (bytes): the audio data (wrapped in ``io.BytesIO`` below)
            sample_rate (int, optional): the audio sample rate. Defaults to 16000.

        Returns:
            numpy.ndarray: the embedding produced by the model backbone, or
                ``np.array([0.0])`` when the executor's audio check fails.

        Raises:
            Exception: whatever the feature extraction raised; the error is
                logged and re-raised so the caller can build an error response.
        """
        # we can not reuse the cache io.BytesIO(audio) data,
        # because the soundfile will change the io.BytesIO(audio) to the end
        # thus we should create a new io.BytesIO each time we need the audio data
        if not self.executor._check(io.BytesIO(audio), sample_rate):
            logger.info("check the audio sample rate occurs error")
            return np.array([0.0])

        waveform, _ = load_audio(io.BytesIO(audio))
        logger.info(f"load the audio sample points, shape is: {waveform.shape}")

        # stage 2: get the audio feat
        # Note: Now we only support fbank feature
        try:
            feats = melspectrogram(
                x=waveform,
                sr=self.config.sr,
                n_mels=self.config.n_mels,
                window_size=self.config.window_size,
                hop_length=self.config.hop_size)
            logger.info(f"extract the audio feats, shape is: {feats.shape}")
        except Exception as e:
            # The previous code called sys.exit(-1) here without importing
            # sys, which raised NameError; a request handler must not stop
            # the server process anyway. Propagate the real error instead.
            logger.error(f"feats occurs exception {e}")
            raise

        feats = paddle.to_tensor(feats).unsqueeze(0)
        # in inference period, the lengths is all one without padding
        lengths = paddle.ones([1])

        # stage 3: we do feature normalize,
        #          Now we assume that the feats must do normalize
        feats = feature_normalize(feats, mean_norm=True, std_norm=False)

        # stage 4: run the model backbone on the normalized feats
        logger.info(f"feats shape: {feats.shape}")
        logger.info("audio extract the feats success")

        logger.info("start to extract the audio embedding")
        embedding = self.model.backbone(feats, lengths).squeeze().numpy()
        logger.info(f"embedding size: {embedding.shape}")

        return embedding
|
||||||
|
|
||||||
|
|
||||||
|
class VectorServerExecutor(VectorExecutor):
    def __init__(self):
        """Server-side wrapper around VectorExecutor; adds no behavior of its own.
        """
        super(VectorServerExecutor, self).__init__()
|
||||||
|
|
||||||
|
|
||||||
|
class VectorEngine(BaseEngine):
    def __init__(self):
        """The Vector Engine
        """
        super(VectorEngine, self).__init__()
        logger.info("Create the VectorEngine Instance")

    def init(self, config: dict):
        """Init the Vector Engine

        Args:
            config (dict): The server configuration. It is read through
                attribute access (``config.device``, ``config.model_type``,
                ``config.cfg_path``, ``config.ckpt_path``, ``config.task``).

        Returns:
            bool: True when the device was set and the executor was
                initialized, False when selecting the device failed.
        """
        logger.info("Init the vector engine")
        try:
            self.config = config
            # Fall back to paddle's current device when none is configured.
            if self.config.device:
                self.device = self.config.device
            else:
                self.device = paddle.get_device()

            paddle.set_device(self.device)
            logger.info(f"Vector Engine set the device: {self.device}")
        except Exception as e:
            # Only ordinary errors are reported as an init failure, so
            # KeyboardInterrupt/SystemExit are no longer swallowed here.
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )
            # self.device may be unbound if the failure happened while the
            # config was being read, so look it up defensively.
            logger.error("Initialize Vector server engine Failed on device: %s."
                         % (getattr(self, "device", None)))
            logger.error(f"Set device error detail: {e}")
            return False

        self.executor = VectorServerExecutor()

        self.executor._init_from_path(
            model_type=config.model_type,
            cfg_path=config.cfg_path,
            ckpt_path=config.ckpt_path,
            task=config.task)

        logger.info("Init the Vector engine successfully")
        return True
|
@ -0,0 +1,151 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import base64
|
||||||
|
import traceback
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from paddlespeech.cli.log import logger
|
||||||
|
from paddlespeech.server.engine.engine_pool import get_engine_pool
|
||||||
|
from paddlespeech.server.engine.vector.python.vector_engine import PaddleVectorConnectionHandler
|
||||||
|
from paddlespeech.server.restful.request import VectorRequest
|
||||||
|
from paddlespeech.server.restful.request import VectorScoreRequest
|
||||||
|
from paddlespeech.server.restful.response import ErrorResponse
|
||||||
|
from paddlespeech.server.restful.response import VectorResponse
|
||||||
|
from paddlespeech.server.restful.response import VectorScoreResponse
|
||||||
|
from paddlespeech.server.utils.errors import ErrorCode
|
||||||
|
from paddlespeech.server.utils.errors import failed_response
|
||||||
|
from paddlespeech.server.utils.exception import ServerBaseException
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get('/paddlespeech/vector/help')
def help():
    """Return a fixed example payload for the vector api.

    Returns:
        json: The /paddlespeech/vector api response content
    """
    # Static sample: the vector values are hard-coded, not computed.
    return {
        "success": "True",
        "code": 200,
        "message": {
            "global": "success"
        },
        "vector": [2.3, 3.5, 5.5, 6.2, 2.8, 1.2, 0.3, 3.6],
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/paddlespeech/vector", response_model=Union[VectorResponse, ErrorResponse])
def vector(request_body: VectorRequest):
    """Extract the embedding of the audio carried by the request.

    Args:
        request_body (VectorRequest): the vector request body

    Returns:
        json: the vector response body
    """
    try:
        # The request carries the audio base64-encoded.
        raw_audio = base64.b64decode(request_body.audio)

        # One connection handler per request, built on the pooled engine.
        pool = get_engine_pool()
        engine = pool['vector']
        handler = PaddleVectorConnectionHandler(engine)

        audio_vec = handler.run(raw_audio, request_body.task)

        # The embedding must be a numpy.ndarray to be serialized below.
        if not isinstance(audio_vec, np.ndarray):
            logger.error(
                f"the vector type is not numpy.array, that is: {type(audio_vec)}"
            )
            type_error = ErrorResponse()
            type_error.message.description = f"the vector type is not numpy.array, that is: {type(audio_vec)}"
            return type_error

        response = {
            "success": True,
            "code": 200,
            "message": {
                "description": "success"
            },
            "result": {
                "vec": audio_vec.tolist()
            },
        }
    except ServerBaseException as e:
        response = failed_response(e.error_code, e.msg)
    except BaseException:
        response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
        traceback.print_exc()

    return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/paddlespeech/vector/score",
    response_model=Union[VectorScoreResponse, ErrorResponse])
def score(request_body: VectorScoreRequest):
    """Score the enroll audio against the test audio.

    Args:
        request_body (VectorScoreRequest): the vector score request body

    Returns:
        json: the vector score response body
    """
    try:
        # Both audios arrive base64-encoded.
        enroll_bytes = base64.b64decode(request_body.enroll_audio)
        test_bytes = base64.b64decode(request_body.test_audio)

        # One connection handler per request, built on the pooled engine.
        pool = get_engine_pool()
        engine = pool['vector']
        handler = PaddleVectorConnectionHandler(engine)

        enroll_test_score = handler.get_enroll_test_score(enroll_bytes,
                                                          test_bytes)

        response = {
            "success": True,
            "code": 200,
            "message": {
                "description": "success"
            },
            "result": {
                "score": enroll_test_score
            },
        }
    except ServerBaseException as e:
        response = failed_response(e.error_code, e.msg)
    except BaseException:
        response = failed_response(ErrorCode.SERVER_UNKOWN_ERR)
        traceback.print_exc()

    return response
|
@ -1,13 +1,14 @@
|
|||||||
# Deepspeech2 Streaming ASR
|
# Deepspeech2 Streaming ASR
|
||||||
|
|
||||||
* websocket
|
## Examples
|
||||||
Streaming ASR with websocket.
|
|
||||||
|
|
||||||
* aishell
|
* `websocket` - Streaming ASR with websocket.
|
||||||
Streaming Decoding under aishell dataset, for local WER test and so on.
|
|
||||||
|
* `aishell` - Streaming Decoding under aishell dataset, for local WER test.
|
||||||
|
|
||||||
## More
|
## More
|
||||||
The below is for developing and offline testing:
|
|
||||||
|
> The below is for developing and offline testing. Do not run it unless you know what it is.
|
||||||
* nnet
|
* nnet
|
||||||
* feat
|
* feat
|
||||||
* decoder
|
* decoder
|
||||||
|
@ -0,0 +1,142 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// todo refactor, replace with gtest
|
||||||
|
|
||||||
|
#include "base/flags.h"
|
||||||
|
#include "base/log.h"
|
||||||
|
#include "kaldi/feat/wave-reader.h"
|
||||||
|
#include "kaldi/util/kaldi-io.h"
|
||||||
|
#include "kaldi/util/table-types.h"
|
||||||
|
|
||||||
|
#include "frontend/audio/audio_cache.h"
|
||||||
|
#include "frontend/audio/data_cache.h"
|
||||||
|
#include "frontend/audio/fbank.h"
|
||||||
|
#include "frontend/audio/feature_cache.h"
|
||||||
|
#include "frontend/audio/frontend_itf.h"
|
||||||
|
#include "frontend/audio/normalizer.h"
|
||||||
|
|
||||||
|
DEFINE_string(wav_rspecifier, "", "test wav scp path");
|
||||||
|
DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
|
||||||
|
DEFINE_string(cmvn_file, "", "read cmvn");
|
||||||
|
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
|
||||||
|
DEFINE_int32(num_bins, 161, "fbank num bins");
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
gflags::ParseCommandLineFlags(&argc, &argv, false);
|
||||||
|
google::InitGoogleLogging(argv[0]);
|
||||||
|
|
||||||
|
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
|
||||||
|
FLAGS_wav_rspecifier);
|
||||||
|
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
|
||||||
|
|
||||||
|
int32 num_done = 0, num_err = 0;
|
||||||
|
|
||||||
|
// feature pipeline: wave cache --> povey window
|
||||||
|
// -->fbank --> global cmvn -> feat cache
|
||||||
|
|
||||||
|
std::unique_ptr<ppspeech::FrontendInterface> data_source(
|
||||||
|
new ppspeech::AudioCache(3600 * 1600, false));
|
||||||
|
|
||||||
|
ppspeech::FbankOptions opt;
|
||||||
|
opt.fbank_opts.frame_opts.frame_length_ms = 25;
|
||||||
|
opt.fbank_opts.frame_opts.frame_shift_ms = 10;
|
||||||
|
opt.streaming_chunk = FLAGS_streaming_chunk;
|
||||||
|
opt.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
|
||||||
|
opt.fbank_opts.frame_opts.dither = 0.0;
|
||||||
|
|
||||||
|
std::unique_ptr<ppspeech::FrontendInterface> fbank(
|
||||||
|
new ppspeech::Fbank(opt, std::move(data_source)));
|
||||||
|
|
||||||
|
std::unique_ptr<ppspeech::FrontendInterface> cmvn(
|
||||||
|
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank)));
|
||||||
|
|
||||||
|
ppspeech::FeatureCacheOptions feat_cache_opts;
|
||||||
|
// the feature cache output feature chunk by chunk.
|
||||||
|
// frame_chunk_size : num frame of a chunk.
|
||||||
|
// frame_chunk_stride: chunk sliding window stride.
|
||||||
|
feat_cache_opts.frame_chunk_stride = 1;
|
||||||
|
feat_cache_opts.frame_chunk_size = 1;
|
||||||
|
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
|
||||||
|
LOG(INFO) << "feat dim: " << feature_cache.Dim();
|
||||||
|
|
||||||
|
int sample_rate = 16000;
|
||||||
|
float streaming_chunk = FLAGS_streaming_chunk;
|
||||||
|
int chunk_sample_size = streaming_chunk * sample_rate;
|
||||||
|
LOG(INFO) << "sr: " << sample_rate;
|
||||||
|
LOG(INFO) << "chunk size (s): " << streaming_chunk;
|
||||||
|
LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
|
||||||
|
|
||||||
|
for (; !wav_reader.Done(); wav_reader.Next()) {
|
||||||
|
std::string utt = wav_reader.Key();
|
||||||
|
const kaldi::WaveData& wave_data = wav_reader.Value();
|
||||||
|
LOG(INFO) << "process utt: " << utt;
|
||||||
|
|
||||||
|
int32 this_channel = 0;
|
||||||
|
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
|
||||||
|
this_channel);
|
||||||
|
int tot_samples = waveform.Dim();
|
||||||
|
LOG(INFO) << "wav len (sample): " << tot_samples;
|
||||||
|
|
||||||
|
int sample_offset = 0;
|
||||||
|
std::vector<kaldi::Vector<BaseFloat>> feats;
|
||||||
|
int feature_rows = 0;
|
||||||
|
while (sample_offset < tot_samples) {
|
||||||
|
int cur_chunk_size =
|
||||||
|
std::min(chunk_sample_size, tot_samples - sample_offset);
|
||||||
|
|
||||||
|
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
|
||||||
|
for (int i = 0; i < cur_chunk_size; ++i) {
|
||||||
|
wav_chunk(i) = waveform(sample_offset + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
kaldi::Vector<BaseFloat> features;
|
||||||
|
feature_cache.Accept(wav_chunk);
|
||||||
|
if (cur_chunk_size < chunk_sample_size) {
|
||||||
|
feature_cache.SetFinished();
|
||||||
|
}
|
||||||
|
bool flag = true;
|
||||||
|
do {
|
||||||
|
flag = feature_cache.Read(&features);
|
||||||
|
feats.push_back(features);
|
||||||
|
feature_rows += features.Dim() / feature_cache.Dim();
|
||||||
|
} while (flag == true && features.Dim() != 0);
|
||||||
|
sample_offset += cur_chunk_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
int cur_idx = 0;
|
||||||
|
kaldi::Matrix<kaldi::BaseFloat> features(feature_rows,
|
||||||
|
feature_cache.Dim());
|
||||||
|
for (auto feat : feats) {
|
||||||
|
int num_rows = feat.Dim() / feature_cache.Dim();
|
||||||
|
for (int row_idx = 0; row_idx < num_rows; ++row_idx) {
|
||||||
|
for (size_t col_idx = 0; col_idx < feature_cache.Dim();
|
||||||
|
++col_idx) {
|
||||||
|
features(cur_idx, col_idx) =
|
||||||
|
feat(row_idx * feature_cache.Dim() + col_idx);
|
||||||
|
}
|
||||||
|
++cur_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
feat_writer.Write(utt, features);
|
||||||
|
feature_cache.Reset();
|
||||||
|
|
||||||
|
if (num_done % 50 == 0 && num_done != 0)
|
||||||
|
KALDI_VLOG(2) << "Processed " << num_done << " utterances";
|
||||||
|
num_done++;
|
||||||
|
}
|
||||||
|
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
|
||||||
|
<< " with errors.";
|
||||||
|
return (num_done != 0 ? 0 : 1);
|
||||||
|
}
|
@ -0,0 +1,2 @@
|
|||||||
|
reference:
|
||||||
|
this patch is from WeNet wenet/runtime/core/patch
|
Loading…
Reference in new issue