commit
c552c0877f
@ -1,2 +0,0 @@
|
||||
.eggs
|
||||
*.wav
|
@ -1,9 +0,0 @@
|
||||
# Changelog
|
||||
|
||||
Date: 2022-3-15, Author: Xiaojie Chen.
|
||||
- kaldi and librosa mfcc, fbank, spectrogram.
|
||||
- unit test and benchmark.
|
||||
|
||||
Date: 2022-2-25, Author: Hui Zhang.
|
||||
- Refactor architecture.
|
||||
- dtw distance and mcd style dtw.
|
@ -1,7 +0,0 @@
|
||||
# PaddleAudio
|
||||
|
||||
PaddleAudio is an audio library for PaddlePaddle.
|
||||
|
||||
## Install
|
||||
|
||||
`pip install .`
|
@ -1,19 +0,0 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line.
|
||||
SPHINXOPTS =
|
||||
SPHINXBUILD = sphinx-build
|
||||
SOURCEDIR = source
|
||||
BUILDDIR = build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
@ -1,24 +0,0 @@
|
||||
# Build docs for PaddleAudio
|
||||
|
||||
Execute the following steps in **current directory**.
|
||||
|
||||
## 1. Install
|
||||
|
||||
`pip install Sphinx sphinx_rtd_theme`
|
||||
|
||||
|
||||
## 2. Generate API docs
|
||||
|
||||
Generate API docs from doc string.
|
||||
|
||||
`sphinx-apidoc -fMeT -o source ../paddleaudio ../paddleaudio/utils --templatedir source/_templates`
|
||||
|
||||
|
||||
## 3. Build
|
||||
|
||||
`sphinx-build source _html`
|
||||
|
||||
|
||||
## 4. Preview
|
||||
|
||||
Open `_html/index.html` for page preview.
|
Before Width: | Height: | Size: 4.9 KiB |
@ -1,35 +0,0 @@
|
||||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=source
|
||||
set BUILDDIR=build
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.http://sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
|
||||
|
||||
:end
|
||||
popd
|
@ -1,60 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
'''
|
||||
This module is used to store environmental variables in PaddleAudio.
|
||||
PPAUDIO_HOME --> the root directory for storing PaddleAudio related data. Default to ~/.paddleaudio. Users can change the
|
||||
├ default value through the PPAUDIO_HOME environment variable.
|
||||
├─ MODEL_HOME --> Store model files.
|
||||
└─ DATA_HOME --> Store automatically downloaded datasets.
|
||||
'''
|
||||
import os
|
||||
|
||||
__all__ = [
|
||||
'USER_HOME',
|
||||
'PPAUDIO_HOME',
|
||||
'MODEL_HOME',
|
||||
'DATA_HOME',
|
||||
]
|
||||
|
||||
|
||||
def _get_user_home():
|
||||
return os.path.expanduser('~')
|
||||
|
||||
|
||||
def _get_ppaudio_home():
|
||||
if 'PPAUDIO_HOME' in os.environ:
|
||||
home_path = os.environ['PPAUDIO_HOME']
|
||||
if os.path.exists(home_path):
|
||||
if os.path.isdir(home_path):
|
||||
return home_path
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'The environment variable PPAUDIO_HOME {} is not a directory.'.
|
||||
format(home_path))
|
||||
else:
|
||||
return home_path
|
||||
return os.path.join(_get_user_home(), '.paddleaudio')
|
||||
|
||||
|
||||
def _get_sub_home(directory):
|
||||
home = os.path.join(_get_ppaudio_home(), directory)
|
||||
if not os.path.exists(home):
|
||||
os.makedirs(home)
|
||||
return home
|
||||
|
||||
|
||||
USER_HOME = _get_user_home()
|
||||
PPAUDIO_HOME = _get_ppaudio_home()
|
||||
MODEL_HOME = _get_sub_home('models')
|
||||
DATA_HOME = _get_sub_home('datasets')
|
@ -1,99 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import glob
|
||||
import os
|
||||
|
||||
import setuptools
|
||||
from setuptools.command.install import install
|
||||
from setuptools.command.test import test
|
||||
|
||||
# set the version here
|
||||
VERSION = '0.0.0'
|
||||
|
||||
|
||||
# Inspired by the example at https://pytest.org/latest/goodpractises.html
|
||||
class TestCommand(test):
|
||||
def finalize_options(self):
|
||||
test.finalize_options(self)
|
||||
self.test_args = []
|
||||
self.test_suite = True
|
||||
|
||||
def run(self):
|
||||
self.run_benchmark()
|
||||
super(TestCommand, self).run()
|
||||
|
||||
def run_tests(self):
|
||||
# Run nose ensuring that argv simulates running nosetests directly
|
||||
import nose
|
||||
nose.run_exit(argv=['nosetests', '-w', 'tests'])
|
||||
|
||||
def run_benchmark(self):
|
||||
for benchmark_item in glob.glob('tests/benchmark/*py'):
|
||||
os.system(f'pytest {benchmark_item}')
|
||||
|
||||
|
||||
class InstallCommand(install):
|
||||
def run(self):
|
||||
install.run(self)
|
||||
|
||||
|
||||
def write_version_py(filename='paddleaudio/__init__.py'):
|
||||
with open(filename, "a") as f:
|
||||
f.write(f"__version__ = '{VERSION}'")
|
||||
|
||||
|
||||
def remove_version_py(filename='paddleaudio/__init__.py'):
|
||||
with open(filename, "r") as f:
|
||||
lines = f.readlines()
|
||||
with open(filename, "w") as f:
|
||||
for line in lines:
|
||||
if "__version__" not in line:
|
||||
f.write(line)
|
||||
|
||||
|
||||
remove_version_py()
|
||||
write_version_py()
|
||||
|
||||
setuptools.setup(
|
||||
name="paddleaudio",
|
||||
version=VERSION,
|
||||
author="",
|
||||
author_email="",
|
||||
description="PaddleAudio, in development",
|
||||
long_description="",
|
||||
long_description_content_type="text/markdown",
|
||||
url="",
|
||||
packages=setuptools.find_packages(include=['paddleaudio*']),
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
python_requires='>=3.6',
|
||||
install_requires=[
|
||||
'numpy >= 1.15.0', 'scipy >= 1.0.0', 'resampy >= 0.2.2',
|
||||
'soundfile >= 0.9.0', 'colorlog', 'pathos == 0.2.8'
|
||||
],
|
||||
extras_require={
|
||||
'test': [
|
||||
'nose', 'librosa==0.8.1', 'soundfile==0.10.3.post1',
|
||||
'torchaudio==0.10.2', 'pytest-benchmark'
|
||||
],
|
||||
},
|
||||
cmdclass={
|
||||
'install': InstallCommand,
|
||||
'test': TestCommand,
|
||||
}, )
|
||||
|
||||
remove_version_py()
|
@ -1,49 +0,0 @@
|
||||
# This is the parameter configuration file for PaddleSpeech Serving.
|
||||
|
||||
#################################################################################
|
||||
# SERVER SETTING #
|
||||
#################################################################################
|
||||
host: 0.0.0.0
|
||||
port: 8090
|
||||
|
||||
# The task format in the engin_list is: <speech task>_<engine type>
|
||||
# task choices = ['asr_online', 'tts_online']
|
||||
# protocol = ['websocket', 'http'] (only one can be selected).
|
||||
# websocket only support online engine type.
|
||||
protocol: 'websocket'
|
||||
engine_list: ['asr_online']
|
||||
|
||||
|
||||
#################################################################################
|
||||
# ENGINE CONFIG #
|
||||
#################################################################################
|
||||
|
||||
################################### ASR #########################################
|
||||
################### speech task: asr; engine_type: online #######################
|
||||
asr_online:
|
||||
model_type: 'deepspeech2online_aishell'
|
||||
am_model: # the pdmodel file of am static model [optional]
|
||||
am_params: # the pdiparams file of am static model [optional]
|
||||
lang: 'zh'
|
||||
sample_rate: 16000
|
||||
cfg_path:
|
||||
decode_method:
|
||||
num_decoding_left_chunks:
|
||||
force_yes: True
|
||||
device: # cpu or gpu:id
|
||||
|
||||
am_predictor_conf:
|
||||
device: # set 'gpu:id' or 'cpu'
|
||||
switch_ir_optim: True
|
||||
glog_info: False # True -> print glog
|
||||
summary: True # False -> do not show predictor config
|
||||
|
||||
chunk_buffer_conf:
|
||||
frame_duration_ms: 85
|
||||
shift_ms: 40
|
||||
sample_rate: 16000
|
||||
sample_width: 2
|
||||
window_n: 7 # frame
|
||||
shift_n: 4 # frame
|
||||
window_ms: 25 # ms
|
||||
shift_ms: 10 # ms
|
@ -0,0 +1,84 @@
|
||||
# This is the parameter configuration file for PaddleSpeech Serving.
|
||||
|
||||
#################################################################################
|
||||
# SERVER SETTING #
|
||||
#################################################################################
|
||||
host: 0.0.0.0
|
||||
port: 8090
|
||||
|
||||
# The task format in the engin_list is: <speech task>_<engine type>
|
||||
# task choices = ['asr_online-inference', 'asr_online-onnx']
|
||||
# protocol = ['websocket'] (only one can be selected).
|
||||
# websocket only support online engine type.
|
||||
protocol: 'websocket'
|
||||
engine_list: ['asr_online-inference']
|
||||
|
||||
|
||||
#################################################################################
|
||||
# ENGINE CONFIG #
|
||||
#################################################################################
|
||||
|
||||
################################### ASR #########################################
|
||||
################### speech task: asr; engine_type: online-inference #######################
|
||||
asr_online-inference:
|
||||
model_type: 'deepspeech2online_aishell'
|
||||
am_model: # the pdmodel file of am static model [optional]
|
||||
am_params: # the pdiparams file of am static model [optional]
|
||||
lang: 'zh'
|
||||
sample_rate: 16000
|
||||
cfg_path:
|
||||
decode_method:
|
||||
num_decoding_left_chunks:
|
||||
force_yes: True
|
||||
device: 'cpu' # cpu or gpu:id
|
||||
|
||||
am_predictor_conf:
|
||||
device: # set 'gpu:id' or 'cpu'
|
||||
switch_ir_optim: True
|
||||
glog_info: False # True -> print glog
|
||||
summary: True # False -> do not show predictor config
|
||||
|
||||
chunk_buffer_conf:
|
||||
frame_duration_ms: 80
|
||||
shift_ms: 40
|
||||
sample_rate: 16000
|
||||
sample_width: 2
|
||||
window_n: 7 # frame
|
||||
shift_n: 4 # frame
|
||||
window_ms: 20 # ms
|
||||
shift_ms: 10 # ms
|
||||
|
||||
|
||||
|
||||
################################### ASR #########################################
|
||||
################### speech task: asr; engine_type: online-onnx #######################
|
||||
asr_online-onnx:
|
||||
model_type: 'deepspeech2online_aishell'
|
||||
am_model: # the pdmodel file of onnx am static model [optional]
|
||||
am_params: # the pdiparams file of am static model [optional]
|
||||
lang: 'zh'
|
||||
sample_rate: 16000
|
||||
cfg_path:
|
||||
decode_method:
|
||||
num_decoding_left_chunks:
|
||||
force_yes: True
|
||||
device: 'cpu' # cpu or gpu:id
|
||||
|
||||
# https://onnxruntime.ai/docs/api/python/api_summary.html#inferencesession
|
||||
am_predictor_conf:
|
||||
device: 'cpu' # set 'gpu:id' or 'cpu'
|
||||
graph_optimization_level: 0
|
||||
intra_op_num_threads: 0 # Sets the number of threads used to parallelize the execution within nodes.
|
||||
inter_op_num_threads: 0 # Sets the number of threads used to parallelize the execution of the graph (across nodes).
|
||||
log_severity_level: 2 # Log severity level. Applies to session load, initialization, etc. 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2.
|
||||
log_verbosity_level: 0 # VLOG level if DEBUG build and session_log_severity_level is 0. Applies to session load, initialization, etc. Default is 0.
|
||||
|
||||
chunk_buffer_conf:
|
||||
frame_duration_ms: 85
|
||||
shift_ms: 40
|
||||
sample_rate: 16000
|
||||
sample_width: 2
|
||||
window_n: 7 # frame
|
||||
shift_n: 4 # frame
|
||||
window_ms: 25 # ms
|
||||
shift_ms: 10 # ms
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,530 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
from typing import ByteString
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from numpy import float32
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.cli.asr.infer import ASRExecutor
|
||||
from paddlespeech.cli.log import logger
|
||||
from paddlespeech.cli.utils import MODEL_HOME
|
||||
from paddlespeech.resource import CommonTaskResource
|
||||
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||
from paddlespeech.s2t.modules.ctc import CTCDecoder
|
||||
from paddlespeech.s2t.transform.transformation import Transformation
|
||||
from paddlespeech.s2t.utils.utility import UpdateConfig
|
||||
from paddlespeech.server.engine.base_engine import BaseEngine
|
||||
from paddlespeech.server.utils import onnx_infer
|
||||
|
||||
__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
|
||||
|
||||
|
||||
# ASR server connection process class
|
||||
class PaddleASRConnectionHanddler:
|
||||
def __init__(self, asr_engine):
|
||||
"""Init a Paddle ASR Connection Handler instance
|
||||
|
||||
Args:
|
||||
asr_engine (ASREngine): the global asr engine
|
||||
"""
|
||||
super().__init__()
|
||||
logger.info(
|
||||
"create an paddle asr connection handler to process the websocket connection"
|
||||
)
|
||||
self.config = asr_engine.config # server config
|
||||
self.model_config = asr_engine.executor.config
|
||||
self.asr_engine = asr_engine
|
||||
|
||||
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
|
||||
self.model_type = self.asr_engine.executor.model_type
|
||||
self.sample_rate = self.asr_engine.executor.sample_rate
|
||||
# tokens to text
|
||||
self.text_feature = self.asr_engine.executor.text_feature
|
||||
|
||||
# extract feat, new only fbank in conformer model
|
||||
self.preprocess_conf = self.model_config.preprocess_config
|
||||
self.preprocess_args = {"train": False}
|
||||
self.preprocessing = Transformation(self.preprocess_conf)
|
||||
|
||||
# frame window and frame shift, in samples unit
|
||||
self.win_length = self.preprocess_conf.process[0]['win_length']
|
||||
self.n_shift = self.preprocess_conf.process[0]['n_shift']
|
||||
|
||||
assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, (
|
||||
self.sample_rate, self.preprocess_conf.process[0]['fs'])
|
||||
self.frame_shift_in_ms = int(
|
||||
self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
|
||||
|
||||
self.continuous_decoding = self.config.get("continuous_decoding", False)
|
||||
self.init_decoder()
|
||||
self.reset()
|
||||
|
||||
def init_decoder(self):
|
||||
if "deepspeech2" in self.model_type:
|
||||
assert self.continuous_decoding is False, "ds2 model not support endpoint"
|
||||
self.am_predictor = self.asr_engine.executor.am_predictor
|
||||
|
||||
self.decoder = CTCDecoder(
|
||||
odim=self.model_config.output_dim, # <blank> is in vocab
|
||||
enc_n_units=self.model_config.rnn_layer_size * 2,
|
||||
blank_id=self.model_config.blank_id,
|
||||
dropout_rate=0.0,
|
||||
reduction=True, # sum
|
||||
batch_average=True, # sum / batch_size
|
||||
grad_norm_type=self.model_config.get('ctc_grad_norm_type',
|
||||
None))
|
||||
|
||||
cfg = self.model_config.decode
|
||||
decode_batch_size = 1 # for online
|
||||
self.decoder.init_decoder(
|
||||
decode_batch_size, self.text_feature.vocab_list,
|
||||
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
|
||||
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
|
||||
cfg.num_proc_bsearch)
|
||||
else:
|
||||
raise ValueError(f"Not supported: {self.model_type}")
|
||||
|
||||
def model_reset(self):
|
||||
# cache for audio and feat
|
||||
self.remained_wav = None
|
||||
self.cached_feat = None
|
||||
|
||||
def output_reset(self):
|
||||
## outputs
|
||||
# partial/ending decoding results
|
||||
self.result_transcripts = ['']
|
||||
|
||||
def reset_continuous_decoding(self):
|
||||
"""
|
||||
when in continous decoding, reset for next utterance.
|
||||
"""
|
||||
self.global_frame_offset = self.num_frames
|
||||
self.model_reset()
|
||||
|
||||
def reset(self):
|
||||
if "deepspeech2" in self.model_type:
|
||||
# for deepspeech2
|
||||
# init state
|
||||
self.chunk_state_h_box = np.zeros(
|
||||
(self.model_config.num_rnn_layers, 1,
|
||||
self.model_config.rnn_layer_size),
|
||||
dtype=float32)
|
||||
self.chunk_state_c_box = np.zeros(
|
||||
(self.model_config.num_rnn_layers, 1,
|
||||
self.model_config.rnn_layer_size),
|
||||
dtype=float32)
|
||||
self.decoder.reset_decoder(batch_size=1)
|
||||
else:
|
||||
raise NotImplementedError(f"{self.model_type} not support.")
|
||||
|
||||
self.device = None
|
||||
|
||||
## common
|
||||
# global sample and frame step
|
||||
self.num_samples = 0
|
||||
self.global_frame_offset = 0
|
||||
# frame step of cur utterance
|
||||
self.num_frames = 0
|
||||
|
||||
## endpoint
|
||||
self.endpoint_state = False # True for detect endpoint
|
||||
|
||||
## conformer
|
||||
self.model_reset()
|
||||
|
||||
## outputs
|
||||
self.output_reset()
|
||||
|
||||
def extract_feat(self, samples: ByteString):
|
||||
logger.info("Online ASR extract the feat")
|
||||
samples = np.frombuffer(samples, dtype=np.int16)
|
||||
assert samples.ndim == 1
|
||||
|
||||
self.num_samples += samples.shape[0]
|
||||
logger.info(
|
||||
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
|
||||
)
|
||||
|
||||
# self.reamined_wav stores all the samples,
|
||||
# include the original remained_wav and this package samples
|
||||
if self.remained_wav is None:
|
||||
self.remained_wav = samples
|
||||
else:
|
||||
assert self.remained_wav.ndim == 1 # (T,)
|
||||
self.remained_wav = np.concatenate([self.remained_wav, samples])
|
||||
logger.info(
|
||||
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
|
||||
)
|
||||
|
||||
if len(self.remained_wav) < self.win_length:
|
||||
# samples not enough for feature window
|
||||
return 0
|
||||
|
||||
# fbank
|
||||
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
|
||||
x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)
|
||||
|
||||
# feature cache
|
||||
if self.cached_feat is None:
|
||||
self.cached_feat = x_chunk
|
||||
else:
|
||||
assert (len(x_chunk.shape) == 3) # (B,T,D)
|
||||
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
|
||||
self.cached_feat = paddle.concat(
|
||||
[self.cached_feat, x_chunk], axis=1)
|
||||
|
||||
# set the feat device
|
||||
if self.device is None:
|
||||
self.device = self.cached_feat.place
|
||||
|
||||
# cur frame step
|
||||
num_frames = x_chunk.shape[1]
|
||||
|
||||
# global frame step
|
||||
self.num_frames += num_frames
|
||||
|
||||
# update remained wav
|
||||
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
|
||||
|
||||
logger.info(
|
||||
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
|
||||
)
|
||||
logger.info(
|
||||
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
|
||||
)
|
||||
logger.info(f"global samples: {self.num_samples}")
|
||||
logger.info(f"global frames: {self.num_frames}")
|
||||
|
||||
def decode(self, is_finished=False):
|
||||
"""advance decoding
|
||||
|
||||
Args:
|
||||
is_finished (bool, optional): Is last frame or not. Defaults to False.
|
||||
|
||||
Returns:
|
||||
None:
|
||||
"""
|
||||
if "deepspeech2" in self.model_type:
|
||||
decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit
|
||||
|
||||
context = 7 # context=7, in audio frame unit
|
||||
subsampling = 4 # subsampling=4, in audio frame unit
|
||||
|
||||
cached_feature_num = context - subsampling
|
||||
# decoding window for model, in audio frame unit
|
||||
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
||||
# decoding stride for model, in audio frame unit
|
||||
stride = subsampling * decoding_chunk_size
|
||||
|
||||
if self.cached_feat is None:
|
||||
logger.info("no audio feat, please input more pcm data")
|
||||
return
|
||||
|
||||
num_frames = self.cached_feat.shape[1]
|
||||
logger.info(
|
||||
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
|
||||
)
|
||||
|
||||
# the cached feat must be larger decoding_window
|
||||
if num_frames < decoding_window and not is_finished:
|
||||
logger.info(
|
||||
f"frame feat num is less than {decoding_window}, please input more pcm data"
|
||||
)
|
||||
return None, None
|
||||
|
||||
# if is_finished=True, we need at least context frames
|
||||
if num_frames < context:
|
||||
logger.info(
|
||||
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
|
||||
)
|
||||
return None, None
|
||||
|
||||
logger.info("start to do model forward")
|
||||
# num_frames - context + 1 ensure that current frame can get context window
|
||||
if is_finished:
|
||||
# if get the finished chunk, we need process the last context
|
||||
left_frames = context
|
||||
else:
|
||||
# we only process decoding_window frames for one chunk
|
||||
left_frames = decoding_window
|
||||
|
||||
end = None
|
||||
for cur in range(0, num_frames - left_frames + 1, stride):
|
||||
end = min(cur + decoding_window, num_frames)
|
||||
|
||||
# extract the audio
|
||||
x_chunk = self.cached_feat[:, cur:end, :].numpy()
|
||||
x_chunk_lens = np.array([x_chunk.shape[1]])
|
||||
|
||||
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
|
||||
|
||||
self.result_transcripts = [trans_best]
|
||||
|
||||
# update feat cache
|
||||
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
|
||||
|
||||
# return trans_best[0]
|
||||
else:
|
||||
raise Exception(f"{self.model_type} not support paddleinference.")
|
||||
|
||||
@paddle.no_grad()
|
||||
def decode_one_chunk(self, x_chunk, x_chunk_lens):
|
||||
"""forward one chunk frames
|
||||
|
||||
Args:
|
||||
x_chunk (np.ndarray): (B,T,D), audio frames.
|
||||
x_chunk_lens ([type]): (B,), audio frame lens
|
||||
|
||||
Returns:
|
||||
logprob: poster probability.
|
||||
"""
|
||||
logger.info("start to decoce one chunk for deepspeech2")
|
||||
# state_c, state_h, audio_lens, audio
|
||||
# 'chunk_state_c_box', 'chunk_state_h_box', 'audio_chunk_lens', 'audio_chunk'
|
||||
input_names = [n.name for n in self.am_predictor.get_inputs()]
|
||||
logger.info(f"ort inputs: {input_names}")
|
||||
# 'softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'
|
||||
# audio, audio_lens, state_h, state_c
|
||||
output_names = [n.name for n in self.am_predictor.get_outputs()]
|
||||
logger.info(f"ort outpus: {output_names}")
|
||||
assert (len(input_names) == len(output_names))
|
||||
assert isinstance(input_names[0], str)
|
||||
|
||||
input_datas = [
|
||||
self.chunk_state_c_box, self.chunk_state_h_box, x_chunk_lens,
|
||||
x_chunk
|
||||
]
|
||||
feeds = dict(zip(input_names, input_datas))
|
||||
|
||||
outputs = self.am_predictor.run([*output_names], {**feeds})
|
||||
|
||||
output_chunk_probs, output_chunk_lens, self.chunk_state_h_box, self.chunk_state_c_box = outputs
|
||||
self.decoder.next(output_chunk_probs, output_chunk_lens)
|
||||
trans_best, trans_beam = self.decoder.decode()
|
||||
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
|
||||
return trans_best[0]
|
||||
|
||||
def get_result(self):
|
||||
"""return partial/ending asr result.
|
||||
|
||||
Returns:
|
||||
str: one best result of partial/ending.
|
||||
"""
|
||||
if len(self.result_transcripts) > 0:
|
||||
return self.result_transcripts[0]
|
||||
else:
|
||||
return ''
|
||||
|
||||
def get_word_time_stamp(self):
|
||||
return []
|
||||
|
||||
@paddle.no_grad()
|
||||
def rescoring(self):
|
||||
...
|
||||
|
||||
|
||||
class ASRServerExecutor(ASRExecutor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.task_resource = CommonTaskResource(
|
||||
task='asr', model_format='onnx', inference_mode='online')
|
||||
|
||||
def update_config(self) -> None:
|
||||
if "deepspeech2" in self.model_type:
|
||||
with UpdateConfig(self.config):
|
||||
# download lm
|
||||
self.config.decode.lang_model_path = os.path.join(
|
||||
MODEL_HOME, 'language_model',
|
||||
self.config.decode.lang_model_path)
|
||||
|
||||
lm_url = self.task_resource.res_dict['lm_url']
|
||||
lm_md5 = self.task_resource.res_dict['lm_md5']
|
||||
logger.info(f"Start to load language model {lm_url}")
|
||||
self.download_lm(
|
||||
lm_url,
|
||||
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{self.model_type} not support paddleinference.")
|
||||
|
||||
def init_model(self) -> None:
|
||||
|
||||
if "deepspeech2" in self.model_type:
|
||||
# AM predictor
|
||||
logger.info("ASR engine start to init the am predictor")
|
||||
self.am_predictor = onnx_infer.get_sess(
|
||||
model_path=self.am_model, sess_conf=self.am_predictor_conf)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{self.model_type} not support paddleinference.")
|
||||
|
||||
def _init_from_path(self,
|
||||
model_type: str=None,
|
||||
am_model: Optional[os.PathLike]=None,
|
||||
am_params: Optional[os.PathLike]=None,
|
||||
lang: str='zh',
|
||||
sample_rate: int=16000,
|
||||
cfg_path: Optional[os.PathLike]=None,
|
||||
decode_method: str='attention_rescoring',
|
||||
num_decoding_left_chunks: int=-1,
|
||||
am_predictor_conf: dict=None):
|
||||
"""
|
||||
Init model and other resources from a specific path.
|
||||
"""
|
||||
if not model_type or not lang or not sample_rate:
|
||||
logger.error(
|
||||
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
|
||||
)
|
||||
return False
|
||||
assert am_params is None, "am_params not used in onnx engine"
|
||||
|
||||
self.model_type = model_type
|
||||
self.sample_rate = sample_rate
|
||||
self.decode_method = decode_method
|
||||
self.num_decoding_left_chunks = num_decoding_left_chunks
|
||||
# conf for paddleinference predictor or onnx
|
||||
self.am_predictor_conf = am_predictor_conf
|
||||
logger.info(f"model_type: {self.model_type}")
|
||||
|
||||
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
|
||||
tag = model_type + '-' + lang + '-' + sample_rate_str
|
||||
self.task_resource.set_task_model(model_tag=tag)
|
||||
|
||||
if cfg_path is None:
|
||||
self.res_path = self.task_resource.res_dir
|
||||
self.cfg_path = os.path.join(
|
||||
self.res_path, self.task_resource.res_dict['cfg_path'])
|
||||
else:
|
||||
self.cfg_path = os.path.abspath(cfg_path)
|
||||
self.res_path = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(self.cfg_path)))
|
||||
|
||||
self.am_model = os.path.join(self.res_path, self.task_resource.res_dict[
|
||||
'onnx_model']) if am_model is None else os.path.abspath(am_model)
|
||||
|
||||
# self.am_params = os.path.join(
|
||||
# self.res_path, self.task_resource.res_dict[
|
||||
# 'params']) if am_params is None else os.path.abspath(am_params)
|
||||
|
||||
logger.info("Load the pretrained model:")
|
||||
logger.info(f" tag = {tag}")
|
||||
logger.info(f" res_path: {self.res_path}")
|
||||
logger.info(f" cfg path: {self.cfg_path}")
|
||||
logger.info(f" am_model path: {self.am_model}")
|
||||
# logger.info(f" am_params path: {self.am_params}")
|
||||
|
||||
#Init body.
|
||||
self.config = CfgNode(new_allowed=True)
|
||||
self.config.merge_from_file(self.cfg_path)
|
||||
|
||||
if self.config.spm_model_prefix:
|
||||
self.config.spm_model_prefix = os.path.join(
|
||||
self.res_path, self.config.spm_model_prefix)
|
||||
logger.info(f"spm model path: {self.config.spm_model_prefix}")
|
||||
|
||||
self.vocab = self.config.vocab_filepath
|
||||
|
||||
self.text_feature = TextFeaturizer(
|
||||
unit_type=self.config.unit_type,
|
||||
vocab=self.config.vocab_filepath,
|
||||
spm_model_prefix=self.config.spm_model_prefix)
|
||||
|
||||
self.update_config()
|
||||
|
||||
# AM predictor
|
||||
self.init_model()
|
||||
|
||||
logger.info(f"create the {model_type} model success")
|
||||
return True
|
||||
|
||||
|
||||
class ASREngine(BaseEngine):
|
||||
"""ASR model resource
|
||||
|
||||
Args:
|
||||
metaclass: Defaults to Singleton.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ASREngine, self).__init__()
|
||||
|
||||
def init_model(self) -> bool:
|
||||
if not self.executor._init_from_path(
|
||||
model_type=self.config.model_type,
|
||||
am_model=self.config.am_model,
|
||||
am_params=self.config.am_params,
|
||||
lang=self.config.lang,
|
||||
sample_rate=self.config.sample_rate,
|
||||
cfg_path=self.config.cfg_path,
|
||||
decode_method=self.config.decode_method,
|
||||
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
|
||||
am_predictor_conf=self.config.am_predictor_conf):
|
||||
return False
|
||||
return True
|
||||
|
||||
def init(self, config: dict) -> bool:
|
||||
"""init engine resource
|
||||
|
||||
Args:
|
||||
config_file (str): config file
|
||||
|
||||
Returns:
|
||||
bool: init failed or success
|
||||
"""
|
||||
self.config = config
|
||||
self.executor = ASRServerExecutor()
|
||||
|
||||
try:
|
||||
self.device = self.config.get("device", paddle.get_device())
|
||||
paddle.set_device(self.device)
|
||||
except BaseException as e:
|
||||
logger.error(
|
||||
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
|
||||
)
|
||||
logger.error(
|
||||
"If all GPU or XPU is used, you can set the server to 'cpu'")
|
||||
sys.exit(-1)
|
||||
|
||||
logger.info(f"paddlespeech_server set the device: {self.device}")
|
||||
|
||||
if not self.init_model():
|
||||
logger.error(
|
||||
"Init the ASR server occurs error, please check the server configuration yaml"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info("Initialize ASR server engine successfully.")
|
||||
return True
|
||||
|
||||
def new_handler(self):
|
||||
"""New handler from model.
|
||||
|
||||
Returns:
|
||||
PaddleASRConnectionHanddler: asr handler instance
|
||||
"""
|
||||
return PaddleASRConnectionHanddler(self)
|
||||
|
||||
def preprocess(self, *args, **kwargs):
|
||||
raise NotImplementedError("Online not using this.")
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
raise NotImplementedError("Online not using this.")
|
||||
|
||||
def postprocess(self):
|
||||
raise NotImplementedError("Online not using this.")
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,545 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
from typing import ByteString
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from numpy import float32
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.cli.asr.infer import ASRExecutor
|
||||
from paddlespeech.cli.log import logger
|
||||
from paddlespeech.cli.utils import MODEL_HOME
|
||||
from paddlespeech.resource import CommonTaskResource
|
||||
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
|
||||
from paddlespeech.s2t.modules.ctc import CTCDecoder
|
||||
from paddlespeech.s2t.transform.transformation import Transformation
|
||||
from paddlespeech.s2t.utils.utility import UpdateConfig
|
||||
from paddlespeech.server.engine.base_engine import BaseEngine
|
||||
from paddlespeech.server.utils.paddle_predictor import init_predictor
|
||||
|
||||
__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
|
||||
|
||||
|
||||
# ASR server connection process class
|
||||
class PaddleASRConnectionHanddler:
|
||||
def __init__(self, asr_engine):
|
||||
"""Init a Paddle ASR Connection Handler instance
|
||||
|
||||
Args:
|
||||
asr_engine (ASREngine): the global asr engine
|
||||
"""
|
||||
super().__init__()
|
||||
logger.info(
|
||||
"create an paddle asr connection handler to process the websocket connection"
|
||||
)
|
||||
self.config = asr_engine.config # server config
|
||||
self.model_config = asr_engine.executor.config
|
||||
self.asr_engine = asr_engine
|
||||
|
||||
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
|
||||
self.model_type = self.asr_engine.executor.model_type
|
||||
self.sample_rate = self.asr_engine.executor.sample_rate
|
||||
# tokens to text
|
||||
self.text_feature = self.asr_engine.executor.text_feature
|
||||
|
||||
# extract feat, new only fbank in conformer model
|
||||
self.preprocess_conf = self.model_config.preprocess_config
|
||||
self.preprocess_args = {"train": False}
|
||||
self.preprocessing = Transformation(self.preprocess_conf)
|
||||
|
||||
# frame window and frame shift, in samples unit
|
||||
self.win_length = self.preprocess_conf.process[0]['win_length']
|
||||
self.n_shift = self.preprocess_conf.process[0]['n_shift']
|
||||
|
||||
assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, (
|
||||
self.sample_rate, self.preprocess_conf.process[0]['fs'])
|
||||
self.frame_shift_in_ms = int(
|
||||
self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
|
||||
|
||||
self.continuous_decoding = self.config.get("continuous_decoding", False)
|
||||
self.init_decoder()
|
||||
self.reset()
|
||||
|
||||
def init_decoder(self):
|
||||
if "deepspeech2" in self.model_type:
|
||||
assert self.continuous_decoding is False, "ds2 model not support endpoint"
|
||||
self.am_predictor = self.asr_engine.executor.am_predictor
|
||||
|
||||
self.decoder = CTCDecoder(
|
||||
odim=self.model_config.output_dim, # <blank> is in vocab
|
||||
enc_n_units=self.model_config.rnn_layer_size * 2,
|
||||
blank_id=self.model_config.blank_id,
|
||||
dropout_rate=0.0,
|
||||
reduction=True, # sum
|
||||
batch_average=True, # sum / batch_size
|
||||
grad_norm_type=self.model_config.get('ctc_grad_norm_type',
|
||||
None))
|
||||
|
||||
cfg = self.model_config.decode
|
||||
decode_batch_size = 1 # for online
|
||||
self.decoder.init_decoder(
|
||||
decode_batch_size, self.text_feature.vocab_list,
|
||||
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
|
||||
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
|
||||
cfg.num_proc_bsearch)
|
||||
else:
|
||||
raise ValueError(f"Not supported: {self.model_type}")
|
||||
|
||||
def model_reset(self):
|
||||
# cache for audio and feat
|
||||
self.remained_wav = None
|
||||
self.cached_feat = None
|
||||
|
||||
def output_reset(self):
|
||||
## outputs
|
||||
# partial/ending decoding results
|
||||
self.result_transcripts = ['']
|
||||
|
||||
def reset_continuous_decoding(self):
|
||||
"""
|
||||
when in continous decoding, reset for next utterance.
|
||||
"""
|
||||
self.global_frame_offset = self.num_frames
|
||||
self.model_reset()
|
||||
|
||||
def reset(self):
|
||||
if "deepspeech2" in self.model_type:
|
||||
# for deepspeech2
|
||||
# init state
|
||||
self.chunk_state_h_box = np.zeros(
|
||||
(self.model_config.num_rnn_layers, 1,
|
||||
self.model_config.rnn_layer_size),
|
||||
dtype=float32)
|
||||
self.chunk_state_c_box = np.zeros(
|
||||
(self.model_config.num_rnn_layers, 1,
|
||||
self.model_config.rnn_layer_size),
|
||||
dtype=float32)
|
||||
self.decoder.reset_decoder(batch_size=1)
|
||||
else:
|
||||
raise NotImplementedError(f"{self.model_type} not support.")
|
||||
|
||||
self.device = None
|
||||
|
||||
## common
|
||||
# global sample and frame step
|
||||
self.num_samples = 0
|
||||
self.global_frame_offset = 0
|
||||
# frame step of cur utterance
|
||||
self.num_frames = 0
|
||||
|
||||
## endpoint
|
||||
self.endpoint_state = False # True for detect endpoint
|
||||
|
||||
## conformer
|
||||
self.model_reset()
|
||||
|
||||
## outputs
|
||||
self.output_reset()
|
||||
|
||||
def extract_feat(self, samples: ByteString):
|
||||
logger.info("Online ASR extract the feat")
|
||||
samples = np.frombuffer(samples, dtype=np.int16)
|
||||
assert samples.ndim == 1
|
||||
|
||||
self.num_samples += samples.shape[0]
|
||||
logger.info(
|
||||
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
|
||||
)
|
||||
|
||||
# self.reamined_wav stores all the samples,
|
||||
# include the original remained_wav and this package samples
|
||||
if self.remained_wav is None:
|
||||
self.remained_wav = samples
|
||||
else:
|
||||
assert self.remained_wav.ndim == 1 # (T,)
|
||||
self.remained_wav = np.concatenate([self.remained_wav, samples])
|
||||
logger.info(
|
||||
f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
|
||||
)
|
||||
|
||||
if len(self.remained_wav) < self.win_length:
|
||||
# samples not enough for feature window
|
||||
return 0
|
||||
|
||||
# fbank
|
||||
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
|
||||
x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)
|
||||
|
||||
# feature cache
|
||||
if self.cached_feat is None:
|
||||
self.cached_feat = x_chunk
|
||||
else:
|
||||
assert (len(x_chunk.shape) == 3) # (B,T,D)
|
||||
assert (len(self.cached_feat.shape) == 3) # (B,T,D)
|
||||
self.cached_feat = paddle.concat(
|
||||
[self.cached_feat, x_chunk], axis=1)
|
||||
|
||||
# set the feat device
|
||||
if self.device is None:
|
||||
self.device = self.cached_feat.place
|
||||
|
||||
# cur frame step
|
||||
num_frames = x_chunk.shape[1]
|
||||
|
||||
# global frame step
|
||||
self.num_frames += num_frames
|
||||
|
||||
# update remained wav
|
||||
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
|
||||
|
||||
logger.info(
|
||||
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
|
||||
)
|
||||
logger.info(
|
||||
f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
|
||||
)
|
||||
logger.info(f"global samples: {self.num_samples}")
|
||||
logger.info(f"global frames: {self.num_frames}")
|
||||
|
||||
def decode(self, is_finished=False):
|
||||
"""advance decoding
|
||||
|
||||
Args:
|
||||
is_finished (bool, optional): Is last frame or not. Defaults to False.
|
||||
|
||||
Returns:
|
||||
None:
|
||||
"""
|
||||
if "deepspeech2" in self.model_type:
|
||||
decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit
|
||||
|
||||
context = 7 # context=7, in audio frame unit
|
||||
subsampling = 4 # subsampling=4, in audio frame unit
|
||||
|
||||
cached_feature_num = context - subsampling
|
||||
# decoding window for model, in audio frame unit
|
||||
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
||||
# decoding stride for model, in audio frame unit
|
||||
stride = subsampling * decoding_chunk_size
|
||||
|
||||
if self.cached_feat is None:
|
||||
logger.info("no audio feat, please input more pcm data")
|
||||
return
|
||||
|
||||
num_frames = self.cached_feat.shape[1]
|
||||
logger.info(
|
||||
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
|
||||
)
|
||||
|
||||
# the cached feat must be larger decoding_window
|
||||
if num_frames < decoding_window and not is_finished:
|
||||
logger.info(
|
||||
f"frame feat num is less than {decoding_window}, please input more pcm data"
|
||||
)
|
||||
return None, None
|
||||
|
||||
# if is_finished=True, we need at least context frames
|
||||
if num_frames < context:
|
||||
logger.info(
|
||||
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
|
||||
)
|
||||
return None, None
|
||||
|
||||
logger.info("start to do model forward")
|
||||
# num_frames - context + 1 ensure that current frame can get context window
|
||||
if is_finished:
|
||||
# if get the finished chunk, we need process the last context
|
||||
left_frames = context
|
||||
else:
|
||||
# we only process decoding_window frames for one chunk
|
||||
left_frames = decoding_window
|
||||
|
||||
end = None
|
||||
for cur in range(0, num_frames - left_frames + 1, stride):
|
||||
end = min(cur + decoding_window, num_frames)
|
||||
|
||||
# extract the audio
|
||||
x_chunk = self.cached_feat[:, cur:end, :].numpy()
|
||||
x_chunk_lens = np.array([x_chunk.shape[1]])
|
||||
|
||||
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
|
||||
|
||||
self.result_transcripts = [trans_best]
|
||||
|
||||
# update feat cache
|
||||
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
|
||||
|
||||
# return trans_best[0]
|
||||
else:
|
||||
raise Exception(f"{self.model_type} not support paddleinference.")
|
||||
|
||||
@paddle.no_grad()
|
||||
def decode_one_chunk(self, x_chunk, x_chunk_lens):
|
||||
"""forward one chunk frames
|
||||
|
||||
Args:
|
||||
x_chunk (np.ndarray): (B,T,D), audio frames.
|
||||
x_chunk_lens ([type]): (B,), audio frame lens
|
||||
|
||||
Returns:
|
||||
logprob: poster probability.
|
||||
"""
|
||||
logger.info("start to decoce one chunk for deepspeech2")
|
||||
input_names = self.am_predictor.get_input_names()
|
||||
audio_handle = self.am_predictor.get_input_handle(input_names[0])
|
||||
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
|
||||
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
|
||||
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
|
||||
|
||||
audio_handle.reshape(x_chunk.shape)
|
||||
audio_handle.copy_from_cpu(x_chunk)
|
||||
|
||||
audio_len_handle.reshape(x_chunk_lens.shape)
|
||||
audio_len_handle.copy_from_cpu(x_chunk_lens)
|
||||
|
||||
h_box_handle.reshape(self.chunk_state_h_box.shape)
|
||||
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
|
||||
|
||||
c_box_handle.reshape(self.chunk_state_c_box.shape)
|
||||
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
|
||||
|
||||
output_names = self.am_predictor.get_output_names()
|
||||
output_handle = self.am_predictor.get_output_handle(output_names[0])
|
||||
output_lens_handle = self.am_predictor.get_output_handle(
|
||||
output_names[1])
|
||||
output_state_h_handle = self.am_predictor.get_output_handle(
|
||||
output_names[2])
|
||||
output_state_c_handle = self.am_predictor.get_output_handle(
|
||||
output_names[3])
|
||||
|
||||
self.am_predictor.run()
|
||||
|
||||
output_chunk_probs = output_handle.copy_to_cpu()
|
||||
output_chunk_lens = output_lens_handle.copy_to_cpu()
|
||||
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
|
||||
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
|
||||
|
||||
self.decoder.next(output_chunk_probs, output_chunk_lens)
|
||||
trans_best, trans_beam = self.decoder.decode()
|
||||
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
|
||||
return trans_best[0]
|
||||
|
||||
def get_result(self):
|
||||
"""return partial/ending asr result.
|
||||
|
||||
Returns:
|
||||
str: one best result of partial/ending.
|
||||
"""
|
||||
if len(self.result_transcripts) > 0:
|
||||
return self.result_transcripts[0]
|
||||
else:
|
||||
return ''
|
||||
|
||||
def get_word_time_stamp(self):
|
||||
return []
|
||||
|
||||
@paddle.no_grad()
|
||||
def rescoring(self):
|
||||
...
|
||||
|
||||
|
||||
class ASRServerExecutor(ASRExecutor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.task_resource = CommonTaskResource(
|
||||
task='asr', model_format='static', inference_mode='online')
|
||||
|
||||
def update_config(self) -> None:
|
||||
if "deepspeech2" in self.model_type:
|
||||
with UpdateConfig(self.config):
|
||||
# download lm
|
||||
self.config.decode.lang_model_path = os.path.join(
|
||||
MODEL_HOME, 'language_model',
|
||||
self.config.decode.lang_model_path)
|
||||
|
||||
lm_url = self.task_resource.res_dict['lm_url']
|
||||
lm_md5 = self.task_resource.res_dict['lm_md5']
|
||||
logger.info(f"Start to load language model {lm_url}")
|
||||
self.download_lm(
|
||||
lm_url,
|
||||
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{self.model_type} not support paddleinference.")
|
||||
|
||||
def init_model(self) -> None:
|
||||
|
||||
if "deepspeech2" in self.model_type:
|
||||
# AM predictor
|
||||
logger.info("ASR engine start to init the am predictor")
|
||||
self.am_predictor = init_predictor(
|
||||
model_file=self.am_model,
|
||||
params_file=self.am_params,
|
||||
predictor_conf=self.am_predictor_conf)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{self.model_type} not support paddleinference.")
|
||||
|
||||
def _init_from_path(self,
|
||||
model_type: str=None,
|
||||
am_model: Optional[os.PathLike]=None,
|
||||
am_params: Optional[os.PathLike]=None,
|
||||
lang: str='zh',
|
||||
sample_rate: int=16000,
|
||||
cfg_path: Optional[os.PathLike]=None,
|
||||
decode_method: str='attention_rescoring',
|
||||
num_decoding_left_chunks: int=-1,
|
||||
am_predictor_conf: dict=None):
|
||||
"""
|
||||
Init model and other resources from a specific path.
|
||||
"""
|
||||
if not model_type or not lang or not sample_rate:
|
||||
logger.error(
|
||||
"The model type or lang or sample rate is None, please input an valid server parameter yaml"
|
||||
)
|
||||
return False
|
||||
|
||||
self.model_type = model_type
|
||||
self.sample_rate = sample_rate
|
||||
self.decode_method = decode_method
|
||||
self.num_decoding_left_chunks = num_decoding_left_chunks
|
||||
# conf for paddleinference predictor or onnx
|
||||
self.am_predictor_conf = am_predictor_conf
|
||||
logger.info(f"model_type: {self.model_type}")
|
||||
|
||||
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
|
||||
tag = model_type + '-' + lang + '-' + sample_rate_str
|
||||
self.task_resource.set_task_model(model_tag=tag)
|
||||
|
||||
if cfg_path is None or am_model is None or am_params is None:
|
||||
self.res_path = self.task_resource.res_dir
|
||||
self.cfg_path = os.path.join(
|
||||
self.res_path, self.task_resource.res_dict['cfg_path'])
|
||||
|
||||
self.am_model = os.path.join(self.res_path,
|
||||
self.task_resource.res_dict['model'])
|
||||
self.am_params = os.path.join(self.res_path,
|
||||
self.task_resource.res_dict['params'])
|
||||
else:
|
||||
self.cfg_path = os.path.abspath(cfg_path)
|
||||
self.am_model = os.path.abspath(am_model)
|
||||
self.am_params = os.path.abspath(am_params)
|
||||
self.res_path = os.path.dirname(
|
||||
os.path.dirname(os.path.abspath(self.cfg_path)))
|
||||
|
||||
logger.info("Load the pretrained model:")
|
||||
logger.info(f" tag = {tag}")
|
||||
logger.info(f" res_path: {self.res_path}")
|
||||
logger.info(f" cfg path: {self.cfg_path}")
|
||||
logger.info(f" am_model path: {self.am_model}")
|
||||
logger.info(f" am_params path: {self.am_params}")
|
||||
|
||||
#Init body.
|
||||
self.config = CfgNode(new_allowed=True)
|
||||
self.config.merge_from_file(self.cfg_path)
|
||||
|
||||
if self.config.spm_model_prefix:
|
||||
self.config.spm_model_prefix = os.path.join(
|
||||
self.res_path, self.config.spm_model_prefix)
|
||||
logger.info(f"spm model path: {self.config.spm_model_prefix}")
|
||||
|
||||
self.vocab = self.config.vocab_filepath
|
||||
|
||||
self.text_feature = TextFeaturizer(
|
||||
unit_type=self.config.unit_type,
|
||||
vocab=self.config.vocab_filepath,
|
||||
spm_model_prefix=self.config.spm_model_prefix)
|
||||
|
||||
self.update_config()
|
||||
|
||||
# AM predictor
|
||||
self.init_model()
|
||||
|
||||
logger.info(f"create the {model_type} model success")
|
||||
return True
|
||||
|
||||
|
||||
class ASREngine(BaseEngine):
|
||||
"""ASR model resource
|
||||
|
||||
Args:
|
||||
metaclass: Defaults to Singleton.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(ASREngine, self).__init__()
|
||||
|
||||
def init_model(self) -> bool:
|
||||
if not self.executor._init_from_path(
|
||||
model_type=self.config.model_type,
|
||||
am_model=self.config.am_model,
|
||||
am_params=self.config.am_params,
|
||||
lang=self.config.lang,
|
||||
sample_rate=self.config.sample_rate,
|
||||
cfg_path=self.config.cfg_path,
|
||||
decode_method=self.config.decode_method,
|
||||
num_decoding_left_chunks=self.config.num_decoding_left_chunks,
|
||||
am_predictor_conf=self.config.am_predictor_conf):
|
||||
return False
|
||||
return True
|
||||
|
||||
def init(self, config: dict) -> bool:
|
||||
"""init engine resource
|
||||
|
||||
Args:
|
||||
config_file (str): config file
|
||||
|
||||
Returns:
|
||||
bool: init failed or success
|
||||
"""
|
||||
self.config = config
|
||||
self.executor = ASRServerExecutor()
|
||||
|
||||
try:
|
||||
self.device = self.config.get("device", paddle.get_device())
|
||||
paddle.set_device(self.device)
|
||||
except BaseException as e:
|
||||
logger.error(
|
||||
f"Set device failed, please check if device '{self.device}' is already used and the parameter 'device' in the yaml file"
|
||||
)
|
||||
logger.error(
|
||||
"If all GPU or XPU is used, you can set the server to 'cpu'")
|
||||
sys.exit(-1)
|
||||
|
||||
logger.info(f"paddlespeech_server set the device: {self.device}")
|
||||
|
||||
if not self.init_model():
|
||||
logger.error(
|
||||
"Init the ASR server occurs error, please check the server configuration yaml"
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info("Initialize ASR server engine successfully.")
|
||||
return True
|
||||
|
||||
def new_handler(self):
|
||||
"""New handler from model.
|
||||
|
||||
Returns:
|
||||
PaddleASRConnectionHanddler: asr handler instance
|
||||
"""
|
||||
return PaddleASRConnectionHanddler(self)
|
||||
|
||||
def preprocess(self, *args, **kwargs):
|
||||
raise NotImplementedError("Online not using this.")
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
raise NotImplementedError("Online not using this.")
|
||||
|
||||
def postprocess(self):
|
||||
raise NotImplementedError("Online not using this.")
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,3 @@
|
||||
data
|
||||
log
|
||||
exp
|
@ -0,0 +1,37 @@
|
||||
# DeepSpeech2 ONNX model
|
||||
|
||||
1. convert deepspeech2 model to ONNX, using Paddle2ONNX.
|
||||
2. check paddleinference and onnxruntime output equal.
|
||||
3. optimize onnx model
|
||||
4. check paddleinference and optimized onnxruntime output equal.
|
||||
|
||||
Please make sure [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX) and [onnx-simplifier](https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape) version is correct.
|
||||
|
||||
The example test with these packages installed:
|
||||
```
|
||||
paddle2onnx 0.9.8rc0 # develop af4354b4e9a61a93be6490640059a02a4499bc7a
|
||||
paddleaudio 0.2.1
|
||||
paddlefsl 1.1.0
|
||||
paddlenlp 2.2.6
|
||||
paddlepaddle-gpu 2.2.2
|
||||
paddlespeech 0.0.0 # develop
|
||||
paddlespeech-ctcdecoders 0.2.0
|
||||
paddlespeech-feat 0.1.0
|
||||
onnx 1.11.0
|
||||
onnx-simplifier 0.0.0 # https://github.com/zh794390558/onnx-simplifier/tree/dyn_time_shape
|
||||
onnxoptimizer 0.2.7
|
||||
onnxruntime 1.11.0
|
||||
```
|
||||
|
||||
## Using
|
||||
|
||||
```
|
||||
bash run.sh
|
||||
```
|
||||
|
||||
For more details please see `run.sh`.
|
||||
|
||||
## Outputs
|
||||
The optimized onnx model is `exp/model.opt.onnx`.
|
||||
|
||||
To show the graph, please using `local/netron.sh`.
|
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime
|
||||
import paddle
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
'--input_file',
|
||||
type=str,
|
||||
default="static_ds2online_inputs.pickle",
|
||||
help="aishell ds2 input data file. For wenetspeech, we only feed for infer model",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--model_type',
|
||||
type=str,
|
||||
default="aishell",
|
||||
help="aishell(1024) or wenetspeech(2048)", )
|
||||
parser.add_argument(
|
||||
'--model_dir', type=str, default=".", help="paddle model dir.")
|
||||
parser.add_argument(
|
||||
'--model_prefix',
|
||||
type=str,
|
||||
default="avg_1.jit",
|
||||
help="paddle model prefix.")
|
||||
parser.add_argument(
|
||||
'--onnx_model',
|
||||
type=str,
|
||||
default='./model.old.onnx',
|
||||
help="onnx model.")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
FLAGS = parse_args()
|
||||
|
||||
# input and output
|
||||
with open(FLAGS.input_file, 'rb') as f:
|
||||
iodict = pickle.load(f)
|
||||
print(iodict.keys())
|
||||
|
||||
audio_chunk = iodict['audio_chunk']
|
||||
audio_chunk_lens = iodict['audio_chunk_lens']
|
||||
chunk_state_h_box = iodict['chunk_state_h_box']
|
||||
chunk_state_c_box = iodict['chunk_state_c_bos']
|
||||
print("raw state shape: ", chunk_state_c_box.shape)
|
||||
|
||||
if FLAGS.model_type == 'wenetspeech':
|
||||
chunk_state_h_box = np.repeat(chunk_state_h_box, 2, axis=-1)
|
||||
chunk_state_c_box = np.repeat(chunk_state_c_box, 2, axis=-1)
|
||||
print("state shape: ", chunk_state_c_box.shape)
|
||||
|
||||
# paddle
|
||||
model = paddle.jit.load(os.path.join(FLAGS.model_dir, FLAGS.model_prefix))
|
||||
res_chunk, res_lens, chunk_state_h, chunk_state_c = model(
|
||||
paddle.to_tensor(audio_chunk),
|
||||
paddle.to_tensor(audio_chunk_lens),
|
||||
paddle.to_tensor(chunk_state_h_box),
|
||||
paddle.to_tensor(chunk_state_c_box), )
|
||||
|
||||
# onnxruntime
|
||||
options = onnxruntime.SessionOptions()
|
||||
options.enable_profiling = True
|
||||
sess = onnxruntime.InferenceSession(FLAGS.onnx_model, sess_options=options)
|
||||
ort_res_chunk, ort_res_lens, ort_chunk_state_h, ort_chunk_state_c = sess.run(
|
||||
['softmax_0.tmp_0', 'tmp_5', 'concat_0.tmp_0', 'concat_1.tmp_0'], {
|
||||
"audio_chunk": audio_chunk,
|
||||
"audio_chunk_lens": audio_chunk_lens,
|
||||
"chunk_state_h_box": chunk_state_h_box,
|
||||
"chunk_state_c_box": chunk_state_c_box
|
||||
})
|
||||
|
||||
print(sess.end_profiling())
|
||||
|
||||
# assert paddle equal ort
|
||||
print(np.allclose(ort_res_chunk, res_chunk, atol=1e-6))
|
||||
print(np.allclose(ort_res_lens, res_lens, atol=1e-6))
|
||||
|
||||
if FLAGS.model_type == 'aishell':
|
||||
print(np.allclose(ort_chunk_state_h, chunk_state_h, atol=1e-6))
|
||||
print(np.allclose(ort_chunk_state_c, chunk_state_c, atol=1e-6))
|
@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
|
||||
# show model
|
||||
|
||||
if [ $# != 1 ];then
|
||||
echo "usage: $0 model_path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
file=$1
|
||||
|
||||
pip install netron
|
||||
netron -p 8082 --host $(hostname -i) $file
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue