Merge pull request #1704 from Honei/server

[asr][websocket] add asr conformer websocket server
pull/1725/head
Hui Zhang 3 years ago committed by GitHub
commit cf9a590fa5

@ -40,7 +40,6 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor'] __all__ = ['ASRExecutor']
@cli_register( @cli_register(
name='paddlespeech.asr', description='Speech to text infer command.') name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor): class ASRExecutor(BaseExecutor):
@ -125,6 +124,7 @@ class ASRExecutor(BaseExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
logger.info("start to init the model")
if hasattr(self, 'model'): if hasattr(self, 'model'):
logger.info('Model had been initialized.') logger.info('Model had been initialized.')
return return
@ -140,13 +140,14 @@ class ASRExecutor(BaseExecutor):
res_path, res_path,
self.pretrained_models[tag]['ckpt_path'] + ".pdparams") self.pretrained_models[tag]['ckpt_path'] + ".pdparams")
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
self.res_path = os.path.dirname( self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
@ -176,7 +177,6 @@ class ASRExecutor(BaseExecutor):
vocab=self.config.vocab_filepath, vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix) spm_model_prefix=self.config.spm_model_prefix)
self.config.decode.decoding_method = decode_method self.config.decode.decoding_method = decode_method
else: else:
raise Exception("wrong type") raise Exception("wrong type")
model_name = model_type[:model_type.rindex( model_name = model_type[:model_type.rindex(
@ -254,12 +254,14 @@ class ASRExecutor(BaseExecutor):
else: else:
raise Exception("wrong type") raise Exception("wrong type")
logger.info("audio feat process success")
@paddle.no_grad() @paddle.no_grad()
def infer(self, model_type: str): def infer(self, model_type: str):
""" """
Model inference and result stored in self.output. Model inference and result stored in self.output.
""" """
logger.info("start to infer the model to get the output")
cfg = self.config.decode cfg = self.config.decode
audio = self._inputs["audio"] audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"] audio_len = self._inputs["audio_len"]
@ -276,6 +278,8 @@ class ASRExecutor(BaseExecutor):
self._outputs["result"] = result_transcripts[0] self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
logger.info(f"we will use the transformer like model : {model_type}")
try:
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
audio, audio,
audio_len, audio_len,
@ -287,6 +291,9 @@ class ASRExecutor(BaseExecutor):
num_decoding_left_chunks=cfg.num_decoding_left_chunks, num_decoding_left_chunks=cfg.num_decoding_left_chunks,
simulate_streaming=cfg.simulate_streaming) simulate_streaming=cfg.simulate_streaming)
self._outputs["result"] = result_transcripts[0][0] self._outputs["result"] = result_transcripts[0][0]
except Exception as e:
logger.exception(e)
else: else:
raise Exception("invalid model name") raise Exception("invalid model name")

@ -88,6 +88,8 @@ model_alias = {
"paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
"conformer": "conformer":
"paddlespeech.s2t.models.u2:U2Model", "paddlespeech.s2t.models.u2:U2Model",
"conformer_online":
"paddlespeech.s2t.models.u2:U2Model",
"transformer": "transformer":
"paddlespeech.s2t.models.u2:U2Model", "paddlespeech.s2t.models.u2:U2Model",
"wenetspeech": "wenetspeech":

@ -286,7 +286,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
# logp: (B*N, vocab) # logp: (B*N, vocab)
logp, cache = self.decoder.forward_one_step( logp, cache = self.decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache) encoder_out, encoder_mask, hyps, hyps_mask, cache)
# 2.2 First beam prune: select topk best prob at current time # 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag) top_k_logp = mask_finished_scores(top_k_logp, end_flag)
@ -708,11 +707,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
batch_size = feats.shape[0] batch_size = feats.shape[0]
if decoding_method in ['ctc_prefix_beam_search', if decoding_method in ['ctc_prefix_beam_search',
'attention_rescoring'] and batch_size > 1: 'attention_rescoring'] and batch_size > 1:
logger.fatal( logger.error(
f'decoding mode {decoding_method} must be running with batch_size == 1' f'decoding mode {decoding_method} must be running with batch_size == 1'
) )
logger.error(f"current batch_size is {batch_size}")
sys.exit(1) sys.exit(1)
if decoding_method == 'attention': if decoding_method == 'attention':
hyps = self.recognize( hyps = self.recognize(
feats, feats,

@ -35,3 +35,16 @@
```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
```
## Online ASR Server
### Launch the online ASR server
```bash
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### Access the online ASR server
```bash
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
```
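For readers who want to skip the CLI client, a hedged sketch of the raw websocket exchange: the `signal`/`name`/`asr_results` fields are the ones visible in this PR, the host/port/route follow `conf/ws_conformer_application.yaml`, and the start message may carry extra fields that are elided from this diff.

```python
# Minimal sketch: stream a 16 kHz PCM16 wav to the /ws/asr route directly.
# Chunking (85 ms per message) mirrors the test client in this diff.
import asyncio
import json

import soundfile
import websockets


async def recognize(wav_path: str, url: str="ws://127.0.0.1:8090/ws/asr"):
    samples, sample_rate = soundfile.read(wav_path, dtype='int16')
    chunk_size = 85 * 16  # 85 ms of 16 kHz audio per websocket message
    async with websockets.connect(url) as ws:
        # tell the server a new utterance starts
        await ws.send(json.dumps({"name": wav_path, "signal": "start"}))
        print(await ws.recv())  # expect {"status": "ok", "signal": "server_ready"}
        # stream the PCM16 bytes chunk by chunk and read partial results
        for start in range(0, len(samples), chunk_size):
            await ws.send(samples[start:start + chunk_size].tobytes())
            print(json.loads(await ws.recv()))  # {"asr_results": "..."}
        # signal the end of the utterance and fetch the final result
        await ws.send(json.dumps({"name": wav_path, "signal": "end"}))
        print(json.loads(await ws.recv())["asr_results"])


if __name__ == "__main__":
    asyncio.get_event_loop().run_until_complete(recognize("input_16k.wav"))
```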

@ -35,3 +35,17 @@
```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
```
## Streaming ASR
### Launch the streaming ASR server
```bash
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### Access the streaming ASR server
```bash
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
```

@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor):
lang=lang, lang=lang,
audio_format=audio_format) audio_format=audio_format)
time_end = time.time() time_end = time.time()
logger.info(res.json()) logger.info(res)
logger.info("Response time %f s." % (time_end - time_start)) logger.info("Response time %f s." % (time_end - time_start))
return True return True
except Exception as e: except Exception as e:
logger.error("Failed to speech recognition.") logger.error("Failed to speech recognition.")
logger.error(e)
return False return False
@stats_wrapper @stats_wrapper
@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor):
logging.info("asr websocket client start") logging.info("asr websocket client start")
handler = ASRAudioHandler(server_ip, port) handler = ASRAudioHandler(server_ip, port)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(input)) res = loop.run_until_complete(handler.run(input))
logging.info("asr websocket client finished") logging.info("asr websocket client finished")
return res['asr_results']
@cli_client_register( @cli_client_register(
name='paddlespeech_client.cls', description='visit cls service') name='paddlespeech_client.cls', description='visit cls service')

@ -41,11 +41,7 @@ asr_online:
shift_ms: 40 shift_ms: 40
sample_rate: 16000 sample_rate: 16000
sample_width: 2 sample_width: 2
window_n: 7 # frame
vad_conf: shift_n: 4 # frame
aggressiveness: 2 window_ms: 20 # ms
sample_rate: 16000 shift_ms: 10 # ms
frame_duration_ms: 20
sample_width: 2
padding_ms: 200
padding_ratio: 0.9

@ -0,0 +1,45 @@
# This is the parameter configuration file for PaddleSpeech Serving.
#################################################################################
# SERVER SETTING #
#################################################################################
host: 0.0.0.0
port: 8090
# The task format in the engine_list is: <speech task>_<engine type>
# task choices = ['asr_online', 'tts_online']
# protocol = ['websocket', 'http'] (only one can be selected).
# websocket only supports the online engine type.
protocol: 'websocket'
engine_list: ['asr_online']
#################################################################################
# ENGINE CONFIG #
#################################################################################
################################### ASR #########################################
################### speech task: asr; engine_type: online #######################
asr_online:
model_type: 'conformer_online_multicn'
am_model: # the pdmodel file of am static model [optional]
am_params: # the pdiparams file of am static model [optional]
lang: 'zh'
sample_rate: 16000
cfg_path:
decode_method:
force_yes: True
am_predictor_conf:
device: # set 'gpu:id' or 'cpu'
switch_ir_optim: True
glog_info: False # True -> print glog
summary: True # False -> do not show predictor config
chunk_buffer_conf:
window_n: 7 # frame
shift_n: 4 # frame
window_ms: 25 # ms
shift_ms: 10 # ms
sample_rate: 16000
sample_width: 2
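A small sketch, assuming `yacs` is installed and the file above is saved as `conf/ws_conformer_application.yaml`, of the ms-to-samples arithmetic applied to `chunk_buffer_conf` (the same conversion the connection handler uses for its window and shift lengths).

```python
# Load the application config and derive window/shift sizes in samples.
from yacs.config import CfgNode

config = CfgNode(new_allowed=True)
config.merge_from_file("conf/ws_conformer_application.yaml")

buf = config.asr_online.chunk_buffer_conf
win_length = int(buf.window_ms / 1000 * buf.sample_rate)  # 25 ms -> 400 samples
n_shift = int(buf.shift_ms / 1000 * buf.sample_rate)      # 10 ms -> 160 samples
print(win_length, n_shift)
```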

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import os import os
from typing import Optional from typing import Optional
@ -20,12 +21,19 @@ from numpy import float32
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.asr.infer import model_alias
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.cli.utils import download_and_decompress
from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.cli.utils import MODEL_HOME
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.transform.transformation import Transformation
from paddlespeech.s2t.utils.dynamic_import import dynamic_import
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
from paddlespeech.s2t.utils.tensor_utils import pad_sequence
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor from paddlespeech.server.utils.paddle_predictor import init_predictor
@ -35,9 +43,9 @@ __all__ = ['ASREngine']
pretrained_models = { pretrained_models = {
"deepspeech2online_aishell-zh-16k": { "deepspeech2online_aishell-zh-16k": {
'url': 'url':
'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz', 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz',
'md5': 'md5':
'd5e076217cf60486519f72c217d21b9b', '23e16c69730a1cb5d735c98c83c21e16',
'cfg_path': 'cfg_path':
'model.yaml', 'model.yaml',
'ckpt_path': 'ckpt_path':
@ -51,16 +59,543 @@ pretrained_models = {
'lm_md5': 'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3' '29e02312deb2e59b3c8686c7966d4fe3'
}, },
"conformer_online_multicn-zh-16k": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz',
'md5':
'0ac93d390552336f2a906aec9e33c5fa',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/chunk_conformer/checkpoints/multi_cn',
'model':
'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
'params':
'exp/chunk_conformer/checkpoints/multi_cn.pdparams',
'lm_url':
'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
'lm_md5':
'29e02312deb2e59b3c8686c7966d4fe3'
},
} }
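For orientation, the key of the new entry is built from the engine config at load time, as sketched below (a paraphrase of `_init_from_path` later in this diff).

```python
# Sketch: how the pretrained_models key ("tag") is derived from the server config.
model_type = 'conformer_online_multicn'  # asr_online.model_type in the server config
lang = 'zh'
sample_rate = 16000

sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str
assert tag == 'conformer_online_multicn-zh-16k'  # matches the dict key above
```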
# ASR server connection process class
class PaddleASRConnectionHanddler:
def __init__(self, asr_engine):
"""Init a Paddle ASR Connection Handler instance
Args:
asr_engine (ASREngine): the global asr engine
"""
super().__init__()
logger.info(
"create an paddle asr connection handler to process the websocket connection"
)
self.config = asr_engine.config
self.model_config = asr_engine.executor.config
self.asr_engine = asr_engine
self.init()
self.reset()
def init(self):
# model_type, sample_rate and text_feature are shared by deepspeech2 and conformer
self.model_type = self.asr_engine.executor.model_type
self.sample_rate = self.asr_engine.executor.sample_rate
# tokens to text
self.text_feature = self.asr_engine.executor.text_feature
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
from paddlespeech.s2t.io.collator import SpeechCollator
self.am_predictor = self.asr_engine.executor.am_predictor
self.collate_fn_test = SpeechCollator.from_config(self.model_config)
self.decoder = CTCDecoder(
odim=self.model_config.output_dim, # <blank> is in vocab
enc_n_units=self.model_config.rnn_layer_size * 2,
blank_id=self.model_config.blank_id,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True, # sum / batch_size
grad_norm_type=self.model_config.get('ctc_grad_norm_type',
None))
cfg = self.model_config.decode
decode_batch_size = 1 # for online
self.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
# frame window samples length and frame shift samples length
self.win_length = int(self.model_config.window_ms / 1000 *
self.sample_rate)
self.n_shift = int(self.model_config.stride_ms / 1000 *
self.sample_rate)
elif "conformer" in self.model_type or "transformer" in self.model_type:
# acoustic model
self.model = self.asr_engine.executor.model
# ctc decoding config
self.ctc_decode_config = self.asr_engine.executor.config.decode
self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
# extract feat; currently only fbank is used for the conformer model
self.preprocess_conf = self.model_config.preprocess_config
self.preprocess_args = {"train": False}
self.preprocessing = Transformation(self.preprocess_conf)
# frame window samples length and frame shift samples length
self.win_length = self.preprocess_conf.process[0]['win_length']
self.n_shift = self.preprocess_conf.process[0]['n_shift']
def extract_feat(self, samples):
if "deepspeech2online" in self.model_type:
# self.remained_wav stores all the samples,
# including the original remained_wav and the samples from this package
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
# pcm16 -> pcm 32
# pcm2float will change the original samples,
# so we should do pcm2float before concatenating
samples = pcm2float(samples)
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The connection remain the audio samples: {self.remained_wav.shape}"
)
# read audio
speech_segment = SpeechSegment.from_pcm(
self.remained_wav, self.sample_rate, transcript=" ")
# audio augment
self.collate_fn_test.augmentation.transform_audio(speech_segment)
# extract speech feature
spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize(
speech_segment, self.collate_fn_test.keep_transcription_text)
# CMVN spectrum
if self.collate_fn_test._normalizer:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
audio = self.collate_fn_test.augmentation.transform_feature(
spectrum)
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32')
# audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
if self.cached_feat is None:
self.cached_feat = audio
else:
assert (len(audio.shape) == 3)
assert (len(self.cached_feat.shape) == 3)
self.cached_feat = paddle.concat(
[self.cached_feat, audio], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
self.num_frames += audio_len
self.remained_wav = self.remained_wav[self.n_shift * audio_len:]
logger.info(
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
)
elif "conformer_online" in self.model_type:
logger.info("Online ASR extract the feat")
samples = np.frombuffer(samples, dtype=np.int16)
assert samples.ndim == 1
logger.info(f"This package receive {samples.shape[0]} pcm data")
self.num_samples += samples.shape[0]
# self.remained_wav stores all the samples,
# including the original remained_wav and the samples from this package
if self.remained_wav is None:
self.remained_wav = samples
else:
assert self.remained_wav.ndim == 1
self.remained_wav = np.concatenate([self.remained_wav, samples])
logger.info(
f"The connection remain the audio samples: {self.remained_wav.shape}"
)
if len(self.remained_wav) < self.win_length:
return 0
# fbank
x_chunk = self.preprocessing(self.remained_wav,
**self.preprocess_args)
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
if self.cached_feat is None:
self.cached_feat = x_chunk
else:
assert (len(x_chunk.shape) == 3)
assert (len(self.cached_feat.shape) == 3)
self.cached_feat = paddle.concat(
[self.cached_feat, x_chunk], axis=1)
# set the feat device
if self.device is None:
self.device = self.cached_feat.place
num_frames = x_chunk.shape[1]
self.num_frames += num_frames
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
)
logger.info(
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
)
# logger.info(f"accumulate samples: {self.num_samples}")
def reset(self):
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
# for deepspeech2
self.chunk_state_h_box = copy.deepcopy(
self.asr_engine.executor.chunk_state_h_box)
self.chunk_state_c_box = copy.deepcopy(
self.asr_engine.executor.chunk_state_c_box)
self.decoder.reset_decoder(batch_size=1)
# for conformer online
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.encoder_out = None
self.cached_feat = None
self.remained_wav = None
self.offset = 0
self.num_samples = 0
self.device = None
self.hyps = []
self.num_frames = 0
self.chunk_num = 0
self.global_frame_offset = 0
self.result_transcripts = ['']
def decode(self, is_finished=False):
if "deepspeech2online" in self.model_type:
# x_chunk is the feature data
decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model
context = 7 # context=7 in deepspeech2 model
subsampling = 4 # subsampling=4 in deepspeech2 model
stride = subsampling * decoding_chunk_size
cached_feature_num = context - subsampling
# decoding window for model
decoding_window = (decoding_chunk_size - 1) * subsampling + context
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
num_frames = self.cached_feat.shape[1]
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
)
# the cached feat must be larger than decoding_window
if num_frames < decoding_window and not is_finished:
logger.info(
f"frame feat num is less than {decoding_window}, please input more pcm data"
)
return None, None
# if is_finished=True, we need at least context frames
if num_frames < context:
logger.info(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return None, None
logger.info("start to do model forward")
# num_frames - context + 1 ensure that current frame can get context window
if is_finished:
# if get the finished chunk, we need process the last context
left_frames = context
else:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
# extract the audio
x_chunk = self.cached_feat[:, cur:end, :].numpy()
x_chunk_lens = np.array([x_chunk.shape[1]])
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
self.result_transcripts = [trans_best]
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
# return trans_best[0]
elif "conformer" in self.model_type or "transformer" in self.model_type:
try:
logger.info(
f"we will use the transformer like model : {self.model_type}"
)
self.advance_decoding(is_finished)
self.update_result()
except Exception as e:
logger.exception(e)
else:
raise Exception("invalid model name")
@paddle.no_grad()
def decode_one_chunk(self, x_chunk, x_chunk_lens):
logger.info("start to decoce one chunk with deepspeech2 model")
input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0])
audio_len_handle = self.am_predictor.get_input_handle(input_names[1])
h_box_handle = self.am_predictor.get_input_handle(input_names[2])
c_box_handle = self.am_predictor.get_input_handle(input_names[3])
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
audio_len_handle.reshape(x_chunk_lens.shape)
audio_len_handle.copy_from_cpu(x_chunk_lens)
h_box_handle.reshape(self.chunk_state_h_box.shape)
h_box_handle.copy_from_cpu(self.chunk_state_h_box)
c_box_handle.reshape(self.chunk_state_c_box.shape)
c_box_handle.copy_from_cpu(self.chunk_state_c_box)
output_names = self.am_predictor.get_output_names()
output_handle = self.am_predictor.get_output_handle(output_names[0])
output_lens_handle = self.am_predictor.get_output_handle(
output_names[1])
output_state_h_handle = self.am_predictor.get_output_handle(
output_names[2])
output_state_c_handle = self.am_predictor.get_output_handle(
output_names[3])
self.am_predictor.run()
output_chunk_probs = output_handle.copy_to_cpu()
output_chunk_lens = output_lens_handle.copy_to_cpu()
self.chunk_state_h_box = output_state_h_handle.copy_to_cpu()
self.chunk_state_c_box = output_state_c_handle.copy_to_cpu()
self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result: {trans_best[0]}")
return trans_best[0]
@paddle.no_grad()
def advance_decoding(self, is_finished=False):
logger.info("start to decode with advanced_decoding method")
cfg = self.ctc_decode_config
decoding_chunk_size = cfg.decoding_chunk_size
num_decoding_left_chunks = cfg.num_decoding_left_chunks
assert decoding_chunk_size > 0
subsampling = self.model.encoder.embed.subsampling_rate
context = self.model.encoder.embed.right_context + 1
stride = subsampling * decoding_chunk_size
cached_feature_num = context - subsampling # processed chunk feature cached for next chunk
# decoding window for model
decoding_window = (decoding_chunk_size - 1) * subsampling + context
if self.cached_feat is None:
logger.info("no audio feat, please input more pcm data")
return
num_frames = self.cached_feat.shape[1]
logger.info(
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
)
# the cached feat must be larger than decoding_window
if num_frames < decoding_window and not is_finished:
logger.info(
f"frame feat num is less than {decoding_window}, please input more pcm data"
)
return None, None
# if is_finished=True, we need at least context frames
if num_frames < context:
logger.info(
"flast {num_frames} is less than context {context} frames, and we cannot do model forward"
)
return None, None
logger.info("start to do model forward")
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
outputs = []
# num_frames - context + 1 ensure that current frame can get context window
if is_finished:
# if get the finished chunk, we need process the last context
left_frames = context
else:
# we only process decoding_window frames for one chunk
left_frames = decoding_window
# record the end for removing the processed feat
end = None
for cur in range(0, num_frames - left_frames + 1, stride):
end = min(cur + decoding_window, num_frames)
self.chunk_num += 1
chunk_xs = self.cached_feat[:, cur:end, :]
(y, self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size,
self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache)
outputs.append(y)
# update the offset
self.offset += y.shape[1]
ys = paddle.cat(outputs, 1)
if self.encoder_out is None:
self.encoder_out = ys
else:
self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
# get the ctc probs
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
self.searcher.search(ctc_probs, self.cached_feat.place)
self.hyps = self.searcher.get_one_best_hyps()
assert self.cached_feat.shape[0] == 1
assert end >= cached_feature_num
self.cached_feat = self.cached_feat[0, end -
cached_feature_num:, :].unsqueeze(0)
assert len(
self.cached_feat.shape
) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
logger.info(
f"This connection handler encoder out shape: {self.encoder_out.shape}"
)
def update_result(self):
logger.info("update the final result")
hyps = self.hyps
self.result_transcripts = [
self.text_feature.defeaturize(hyp) for hyp in hyps
]
self.result_tokenids = [hyp for hyp in hyps]
def get_result(self):
if len(self.result_transcripts) > 0:
return self.result_transcripts[0]
else:
return ''
@paddle.no_grad()
def rescoring(self):
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
return
logger.info("rescoring the final result")
if "attention_rescoring" != self.ctc_decode_config.decoding_method:
return
self.searcher.finalize_search()
self.update_result()
beam_size = self.ctc_decode_config.beam_size
hyps = self.searcher.get_hyps()
if hyps is None or len(hyps) == 0:
return
# assert len(hyps) == beam_size
hyp_list = []
for hyp in hyps:
hyp_content = hyp[0]
# Prevent the hyp from being empty
if len(hyp_content) == 0:
hyp_content = (self.model.ctc.blank_id, )
hyp_content = paddle.to_tensor(
hyp_content, place=self.device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=self.device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
self.model.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at beginning
encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for i, hyp in enumerate(hyps):
score = 0.0
for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w]
# the last decoder output token is `eos`, for the last decoder input token.
score += decoder_out[i][len(hyp[0])][self.model.eos]
# add ctc score (which in ln domain)
score += hyp[1] * self.ctc_decode_config.ctc_weight
if score > best_score:
best_score = score
best_index = i
# update the one best result
logger.info(f"best index: {best_index}")
self.hyps = [hyps[best_index][0]]
self.update_result()
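To make the per-connection call order explicit, here is a condensed sketch of how the websocket route added later in this diff drives the handler; `asr_engine` and `pcm_chunks` are placeholders for the engine taken from the engine pool and the binary websocket messages.

```python
# A sketch of one connection's lifecycle, following the ws_api changes in this PR.
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler


def run_connection(asr_engine, pcm_chunks):
    """Feed PCM16 byte chunks through one connection handler, return the final text."""
    handler = PaddleASRConnectionHanddler(asr_engine)
    for pcm_bytes in pcm_chunks:           # one binary websocket message at a time
        handler.extract_feat(pcm_bytes)    # buffer samples and cache fbank features
        handler.decode(is_finished=False)  # encoder chunk forward + ctc prefix beam search
        partial = handler.get_result()     # running one-best partial transcript
    handler.decode(is_finished=True)       # flush the remaining cached frames
    handler.rescoring()                    # attention rescoring when configured
    final = handler.get_result()
    handler.reset()                        # ready for the next utterance
    return final
```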
class ASRServerExecutor(ASRExecutor): class ASRServerExecutor(ASRExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
pass pass
def _get_pretrained_path(self, tag: str) -> os.PathLike:
"""
Download and returns pretrained resources path of current task.
"""
support_models = list(pretrained_models.keys())
assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format(
tag, '\n\t\t'.join(support_models))
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
decompressed_path = os.path.abspath(decompressed_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_path(self, def _init_from_path(self,
model_type: str='wenetspeech', model_type: str='deepspeech2online_aishell',
am_model: Optional[os.PathLike]=None, am_model: Optional[os.PathLike]=None,
am_params: Optional[os.PathLike]=None, am_params: Optional[os.PathLike]=None,
lang: str='zh', lang: str='zh',
@ -71,12 +606,15 @@ class ASRServerExecutor(ASRExecutor):
""" """
Init model and other resources from a specific path. Init model and other resources from a specific path.
""" """
self.model_type = model_type
self.sample_rate = sample_rate
if cfg_path is None or am_model is None or am_params is None: if cfg_path is None or am_model is None or am_params is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k' sample_rate_str = '16k' if sample_rate == 16000 else '8k'
tag = model_type + '-' + lang + '-' + sample_rate_str tag = model_type + '-' + lang + '-' + sample_rate_str
logger.info(f"Load the pretrained model, tag = {tag}")
res_path = self._get_pretrained_path(tag) # wenetspeech_zh res_path = self._get_pretrained_path(tag) # wenetspeech_zh
self.res_path = res_path self.res_path = res_path
self.cfg_path = os.path.join(res_path, self.cfg_path = os.path.join(res_path,
pretrained_models[tag]['cfg_path']) pretrained_models[tag]['cfg_path'])
@ -85,9 +623,6 @@ class ASRServerExecutor(ASRExecutor):
self.am_params = os.path.join(res_path, self.am_params = os.path.join(res_path,
pretrained_models[tag]['params']) pretrained_models[tag]['params'])
logger.info(res_path) logger.info(res_path)
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
else: else:
self.cfg_path = os.path.abspath(cfg_path) self.cfg_path = os.path.abspath(cfg_path)
self.am_model = os.path.abspath(am_model) self.am_model = os.path.abspath(am_model)
@ -95,6 +630,10 @@ class ASRServerExecutor(ASRExecutor):
self.res_path = os.path.dirname( self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path))) os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.am_model)
logger.info(self.am_params)
#Init body. #Init body.
self.config = CfgNode(new_allowed=True) self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path) self.config.merge_from_file(self.cfg_path)
@ -112,15 +651,40 @@ class ASRServerExecutor(ASRExecutor):
lm_url = pretrained_models[tag]['lm_url'] lm_url = pretrained_models[tag]['lm_url']
lm_md5 = pretrained_models[tag]['lm_md5'] lm_md5 = pretrained_models[tag]['lm_md5']
logger.info(f"Start to load language model {lm_url}")
self.download_lm( self.download_lm(
lm_url, lm_url,
os.path.dirname(self.config.decode.lang_model_path), lm_md5) os.path.dirname(self.config.decode.lang_model_path), lm_md5)
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: elif "conformer" in model_type or "transformer" in model_type:
raise Exception("wrong type") logger.info("start to create the stream conformer asr engine")
if self.config.spm_model_prefix:
self.config.spm_model_prefix = os.path.join(
self.res_path, self.config.spm_model_prefix)
self.vocab = self.config.vocab_filepath
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix)
# update the decoding method
if decode_method:
self.config.decode.decoding_method = decode_method
# we only support the ctc_prefix_beam_search and attention_rescoring decoding methods
# Generally we set the decoding_method to attention_rescoring
if self.config.decode.decoding_method not in [
"ctc_prefix_beam_search", "attention_rescoring"
]:
logger.info(
"we set the decoding_method to attention_rescoring")
self.config.decode.decoding_method = "attention_rescoring"
assert self.config.decode.decoding_method in [
"ctc_prefix_beam_search", "attention_rescoring"
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
else: else:
raise Exception("wrong type") raise Exception("wrong type")
if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
# AM predictor # AM predictor
logger.info("ASR engine start to init the am predictor")
self.am_predictor_conf = am_predictor_conf self.am_predictor_conf = am_predictor_conf
self.am_predictor = init_predictor( self.am_predictor = init_predictor(
model_file=self.am_model, model_file=self.am_model,
@ -128,6 +692,7 @@ class ASRServerExecutor(ASRExecutor):
predictor_conf=self.am_predictor_conf) predictor_conf=self.am_predictor_conf)
# decoder # decoder
logger.info("ASR engine start to create the ctc decoder instance")
self.decoder = CTCDecoder( self.decoder = CTCDecoder(
odim=self.config.output_dim, # <blank> is in vocab odim=self.config.output_dim, # <blank> is in vocab
enc_n_units=self.config.rnn_layer_size * 2, enc_n_units=self.config.rnn_layer_size * 2,
@ -138,6 +703,7 @@ class ASRServerExecutor(ASRExecutor):
grad_norm_type=self.config.get('ctc_grad_norm_type', None)) grad_norm_type=self.config.get('ctc_grad_norm_type', None))
# init decoder # init decoder
logger.info("ASR engine start to init the ctc decoder")
cfg = self.config.decode cfg = self.config.decode
decode_batch_size = 1 # for online decode_batch_size = 1 # for online
self.decoder.init_decoder( self.decoder.init_decoder(
@ -153,10 +719,29 @@ class ASRServerExecutor(ASRExecutor):
self.chunk_state_c_box = np.zeros( self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size), (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32) dtype=float32)
elif "conformer" in model_type or "transformer" in model_type:
model_name = model_type[:model_type.rindex(
'_')] # model_type: {model_name}_{dataset}
logger.info(f"model name: {model_name}")
model_class = dynamic_import(model_name, model_alias)
model_conf = self.config
model = model_class.from_config(model_conf)
self.model = model
self.model.eval()
# load model
model_dict = paddle.load(self.am_model)
self.model.set_state_dict(model_dict)
logger.info("create the transformer like model success")
# update the ctc decoding
self.searcher = CTCPrefixBeamSearch(self.config.decode)
self.transformer_decode_reset()
def reset_decoder_and_chunk(self): def reset_decoder_and_chunk(self):
"""reset decoder and chunk state for an new audio """reset decoder and chunk state for an new audio
""" """
if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type:
self.decoder.reset_decoder(batch_size=1) self.decoder.reset_decoder(batch_size=1)
# init state box, for new audio request # init state box, for new audio request
self.chunk_state_h_box = np.zeros( self.chunk_state_h_box = np.zeros(
@ -165,6 +750,8 @@ class ASRServerExecutor(ASRExecutor):
self.chunk_state_c_box = np.zeros( self.chunk_state_c_box = np.zeros(
(self.config.num_rnn_layers, 1, self.config.rnn_layer_size), (self.config.num_rnn_layers, 1, self.config.rnn_layer_size),
dtype=float32) dtype=float32)
elif "conformer" in self.model_type or "transformer" in self.model_type:
self.transformer_decode_reset()
def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str): def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str):
"""decode one chunk """decode one chunk
@ -175,8 +762,9 @@ class ASRServerExecutor(ASRExecutor):
model_type (str): online model type model_type (str): online model type
Returns: Returns:
[type]: [description] str: one best result
""" """
logger.info("start to decoce chunk by chunk")
if "deepspeech2online" in model_type: if "deepspeech2online" in model_type:
input_names = self.am_predictor.get_input_names() input_names = self.am_predictor.get_input_names()
audio_handle = self.am_predictor.get_input_handle(input_names[0]) audio_handle = self.am_predictor.get_input_handle(input_names[0])
@ -215,14 +803,142 @@ class ASRServerExecutor(ASRExecutor):
self.decoder.next(output_chunk_probs, output_chunk_lens) self.decoder.next(output_chunk_probs, output_chunk_lens)
trans_best, trans_beam = self.decoder.decode() trans_best, trans_beam = self.decoder.decode()
logger.info(f"decode one best result: {trans_best[0]}")
return trans_best[0] return trans_best[0]
elif "conformer" in model_type or "transformer" in model_type: elif "conformer" in model_type or "transformer" in model_type:
raise Exception("invalid model name") try:
logger.info(
f"we will use the transformer like model : {self.model_type}"
)
self.advanced_decoding(x_chunk, x_chunk_lens)
self.update_result()
return self.result_transcripts[0]
except Exception as e:
logger.exception(e)
else: else:
raise Exception("invalid model name") raise Exception("invalid model name")
def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
logger.info("start to decode with advanced_decoding method")
encoder_out, encoder_mask = self.encoder_forward(xs)
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
self.searcher.search(ctc_probs, xs.place)
# update the one best result
self.hyps = self.searcher.get_one_best_hyps()
# now we support ctc_prefix_beam_search and attention_rescoring
if "attention_rescoring" in self.config.decode.decoding_method:
self.rescoring(encoder_out, xs.place)
def encoder_forward(self, xs):
logger.info("get the model out from the feat")
cfg = self.config.decode
decoding_chunk_size = cfg.decoding_chunk_size
num_decoding_left_chunks = cfg.num_decoding_left_chunks
assert decoding_chunk_size > 0
subsampling = self.model.encoder.embed.subsampling_rate
context = self.model.encoder.embed.right_context + 1
stride = subsampling * decoding_chunk_size
# decoding window for model
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
logger.info("start to do model forward")
outputs = []
# num_frames - context + 1 ensure that current frame can get context window
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache) = self.model.encoder.forward_chunk(
chunk_xs, self.offset, required_cache_size,
self.subsampling_cache, self.elayers_output_cache,
self.conformer_cnn_cache)
outputs.append(y)
self.offset += y.shape[1]
ys = paddle.cat(outputs, 1)
masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool)
masks = masks.unsqueeze(1)
return ys, masks
def rescoring(self, encoder_out, device):
logger.info("start to rescoring the hyps")
beam_size = self.config.decode.beam_size
hyps = self.searcher.get_hyps()
assert len(hyps) == beam_size
hyp_list = []
for hyp in hyps:
hyp_content = hyp[0]
# Prevent the hyp from being empty
if len(hyp_content) == 0:
hyp_content = (self.model.ctc.blank_id, )
hyp_content = paddle.to_tensor(
hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
self.model.ignore_id)
hyps_lens = hyps_lens + 1 # Add <sos> at beginning
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.model.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
# Only use decoder score for rescoring
best_score = -float('inf')
best_index = 0
# hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
for i, hyp in enumerate(hyps):
score = 0.0
for j, w in enumerate(hyp[0]):
score += decoder_out[i][j][w]
# the last decoder output token is `eos`, for the last decoder input token.
score += decoder_out[i][len(hyp[0])][self.model.eos]
# add ctc score (which in ln domain)
score += hyp[1] * self.config.decode.ctc_weight
if score > best_score:
best_score = score
best_index = i
# update the one best result
self.hyps = [hyps[best_index][0]]
return hyps[best_index][0]
def transformer_decode_reset(self):
self.subsampling_cache = None
self.elayers_output_cache = None
self.conformer_cnn_cache = None
self.offset = 0
# decoding reset
self.searcher.reset()
def update_result(self):
logger.info("update the final result")
hyps = self.hyps
self.result_transcripts = [
self.text_feature.defeaturize(hyp) for hyp in hyps
]
self.result_tokenids = [hyp for hyp in hyps]
def extract_feat(self, samples, sample_rate): def extract_feat(self, samples, sample_rate):
"""extract feat """extract feat
@ -234,9 +950,10 @@ class ASRServerExecutor(ASRExecutor):
x_chunk (numpy.array): shape[B, T, D] x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B] x_chunk_lens (numpy.array): shape[B]
""" """
if "deepspeech2online" in self.model_type:
# pcm16 -> pcm 32 # pcm16 -> pcm 32
samples = pcm2float(samples) samples = pcm2float(samples)
# read audio # read audio
speech_segment = SpeechSegment.from_pcm( speech_segment = SpeechSegment.from_pcm(
samples, sample_rate, transcript=" ") samples, sample_rate, transcript=" ")
@ -251,7 +968,8 @@ class ASRServerExecutor(ASRExecutor):
spectrum = self.collate_fn_test._normalizer.apply(spectrum) spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment # spectrum augment
audio = self.collate_fn_test.augmentation.transform_feature(spectrum) audio = self.collate_fn_test.augmentation.transform_feature(
spectrum)
audio_len = audio.shape[0] audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32') audio = paddle.to_tensor(audio, dtype='float32')
@ -262,6 +980,28 @@ class ASRServerExecutor(ASRExecutor):
x_chunk_lens = np.array([audio_len]) x_chunk_lens = np.array([audio_len])
return x_chunk, x_chunk_lens return x_chunk, x_chunk_lens
elif "conformer_online" in self.model_type:
if sample_rate != self.sample_rate:
logger.info(f"audio sample rate {sample_rate} is not match,"
"the model sample_rate is {self.sample_rate}")
logger.info(f"ASR Engine use the {self.model_type} to process")
logger.info("Create the preprocess instance")
preprocess_conf = self.config.preprocess_config
preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf)
logger.info("Read the audio file")
logger.info(f"audio shape: {samples.shape}")
# fbank
x_chunk = preprocessing(samples, **preprocess_args)
x_chunk_lens = paddle.to_tensor(x_chunk.shape[0])
x_chunk = paddle.to_tensor(
x_chunk, dtype="float32").unsqueeze(axis=0)
logger.info(
f"process the audio feature success, feat shape: {x_chunk.shape}"
)
return x_chunk, x_chunk_lens
class ASREngine(BaseEngine): class ASREngine(BaseEngine):
@ -273,6 +1013,7 @@ class ASREngine(BaseEngine):
def __init__(self): def __init__(self):
super(ASREngine, self).__init__() super(ASREngine, self).__init__()
logger.info("create the online asr engine instance")
def init(self, config: dict) -> bool: def init(self, config: dict) -> bool:
"""init engine resource """init engine resource
@ -301,7 +1042,10 @@ class ASREngine(BaseEngine):
logger.info("Initialize ASR server engine successfully.") logger.info("Initialize ASR server engine successfully.")
return True return True
def preprocess(self, samples, sample_rate): def preprocess(self,
samples,
sample_rate,
model_type="deepspeech2online_aishell-zh-16k"):
"""preprocess """preprocess
Args: Args:
@ -312,6 +1056,7 @@ class ASREngine(BaseEngine):
x_chunk (numpy.array): shape[B, T, D] x_chunk (numpy.array): shape[B, T, D]
x_chunk_lens (numpy.array): shape[B] x_chunk_lens (numpy.array): shape[B]
""" """
# if "deepspeech" in model_type:
x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate) x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate)
return x_chunk, x_chunk_lens return x_chunk, x_chunk_lens

@ -0,0 +1,128 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddlespeech.cli.log import logger
from paddlespeech.s2t.utils.utility import log_add
__all__ = ['CTCPrefixBeamSearch']
class CTCPrefixBeamSearch:
def __init__(self, config):
"""Implement the ctc prefix beam search
Args:
config (yacs.config.CfgNode): the decoding config (e.g. the model's decode config node)
"""
self.config = config
self.reset()
@paddle.no_grad()
def search(self, ctc_probs, device, blank_id=0):
"""ctc prefix beam search method decode a chunk feature
Args:
xs (paddle.Tensor): feature data
ctc_probs (paddle.Tensor): the ctc probability of all the tokens
device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0).
blank_id (int, optional): the blank id in the vocab. Defaults to 0.
Returns:
list: the search result
"""
# decode
logger.info("start to ctc prefix search")
batch_size = 1
beam_size = self.config.beam_size
maxlen = ctc_probs.shape[0]
assert len(ctc_probs.shape) == 2
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
if self.cur_hyps is None:
self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
# do token passing process
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in self.cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
self.cur_hyps = next_hyps[:beam_size]
self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
logger.info("ctc prefix search success")
return self.hyps
def get_one_best_hyps(self):
"""Return the one best result
Returns:
list: the one best result
"""
return [self.hyps[0][0]]
def get_hyps(self):
"""Return the search hyps
Returns:
list: return the search hyps
"""
return self.hyps
def reset(self):
"""Rest the search cache value
"""
self.cur_hyps = None
self.hyps = None
def finalize_search(self):
"""do nothing in ctc_prefix_beam_search
"""
pass
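A minimal usage sketch of the searcher above, fed with dummy CTC posteriors; the only configuration field it reads is `beam_size`, so a bare `CfgNode` is used here for illustration (the server passes the model's full decode config).

```python
import paddle
from yacs.config import CfgNode

from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch

# dummy ctc posteriors in ln domain: two chunks of 8 frames over a 30-token vocab
chunks = [
    paddle.nn.functional.log_softmax(paddle.randn([8, 30]), axis=-1)
    for _ in range(2)
]

searcher = CTCPrefixBeamSearch(CfgNode({'beam_size': 5}))
for ctc_probs in chunks:                        # feed one chunk at a time
    searcher.search(ctc_probs, ctc_probs.place)
    partial = searcher.get_one_best_hyps()[0]   # token ids of the running best prefix

nbest = searcher.get_hyps()   # [(token_ids, ln-domain score), ...], used for rescoring
searcher.reset()              # clear the prefix cache before the next utterance
```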

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -34,9 +34,8 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str): def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples) x_len = len(samples)
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
chunk_size = 85 * 16 #85ms, sample_rate = 16kHz
if x_len % chunk_size!= 0: if x_len % chunk_size!= 0:
padding_len_x = chunk_size - x_len % chunk_size padding_len_x = chunk_size - x_len % chunk_size
else: else:
@ -48,7 +47,6 @@ class ASRAudioHandler:
assert (x_len + padding_len_x) % chunk_size == 0 assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk) num_chunk = int(num_chunk)
for i in range(0, num_chunk): for i in range(0, num_chunk):
start = i * chunk_size start = i * chunk_size
end = start + chunk_size end = start + chunk_size
@ -57,7 +55,11 @@ class ASRAudioHandler:
async def run(self, wavfile_path: str): async def run(self, wavfile_path: str):
logging.info("send a message to the server") logging.info("send a message to the server")
# self.read_wave()
# send websocket handshake protocol
async with websockets.connect(self.url) as ws: async with websockets.connect(self.url) as ws:
# server has already received the handshake protocol
# client starts to send the command
audio_info = json.dumps( audio_info = json.dumps(
{ {
"name": "test.wav", "name": "test.wav",
@ -78,7 +80,6 @@ class ASRAudioHandler:
msg = json.loads(msg) msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
result = msg
# finished # finished
audio_info = json.dumps( audio_info = json.dumps(
{ {
@ -91,9 +92,11 @@ class ASRAudioHandler:
separators=(',', ': ')) separators=(',', ': '))
await ws.send(audio_info) await ws.send(audio_info)
msg = await ws.recv() msg = await ws.recv()
msg = json.loads(msg)
logging.info("receive msg={}".format(msg))
# decode the bytes to str
msg = json.loads(msg)
logging.info("final receive msg={}".format(msg))
result = msg
return result return result
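For completeness, a hedged sketch of driving this handler from plain Python, mirroring `ASRClientExecutor` above; the import path of `ASRAudioHandler` is not visible in this diff, so the one below is a placeholder.

```python
# Placeholder import path -- point it at wherever ASRAudioHandler is defined in the repo.
import asyncio
import logging

from websocket_client import ASRAudioHandler  # hypothetical module name

logging.basicConfig(level=logging.INFO)

handler = ASRAudioHandler("127.0.0.1", 8090)   # same arguments the CLI client uses
loop = asyncio.get_event_loop()
res = loop.run_until_complete(handler.run("input_16k.wav"))
print(res['asr_results'])                      # final transcript returned by the server
```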

@ -63,12 +63,12 @@ class ChunkBuffer(object):
the sample rate. the sample rate.
Yields Frames of the requested duration. Yields Frames of the requested duration.
""" """
audio = self.remained_audio + audio audio = self.remained_audio + audio
self.remained_audio = b'' self.remained_audio = b''
offset = 0 offset = 0
timestamp = 0.0 timestamp = 0.0
while offset + self.window_bytes <= len(audio): while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp, yield Frame(audio[offset:offset + self.window_bytes], timestamp,
self.window_sec) self.window_sec)

@ -13,12 +13,12 @@
# limitations under the License. # limitations under the License.
import json import json
import numpy as np
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import WebSocket from fastapi import WebSocket
from fastapi import WebSocketDisconnect from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.engine.engine_pool import get_engine_pool
from paddlespeech.server.utils.buffer import ChunkBuffer from paddlespeech.server.utils.buffer import ChunkBuffer
from paddlespeech.server.utils.vad import VADAudio from paddlespeech.server.utils.vad import VADAudio
@ -28,22 +28,25 @@ router = APIRouter()
@router.websocket('/ws/asr') @router.websocket('/ws/asr')
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
await websocket.accept() await websocket.accept()
engine_pool = get_engine_pool() engine_pool = get_engine_pool()
asr_engine = engine_pool['asr'] asr_engine = engine_pool['asr']
connection_handler = None
# init buffer # init buffer
# each websocket connection has its own chunk buffer
chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
chunk_buffer = ChunkBuffer( chunk_buffer = ChunkBuffer(
window_n=7, window_n=chunk_buffer_conf.window_n,
shift_n=4, shift_n=chunk_buffer_conf.shift_n,
window_ms=20, window_ms=chunk_buffer_conf.window_ms,
shift_ms=10, shift_ms=chunk_buffer_conf.shift_ms,
sample_rate=chunk_buffer_conf['sample_rate'], sample_rate=chunk_buffer_conf.sample_rate,
sample_width=chunk_buffer_conf['sample_width']) sample_width=chunk_buffer_conf.sample_width)
# init vad # init vad
vad_conf = asr_engine.config.vad_conf vad_conf = asr_engine.config.get('vad_conf', None)
if vad_conf:
vad = VADAudio( vad = VADAudio(
aggressiveness=vad_conf['aggressiveness'], aggressiveness=vad_conf['aggressiveness'],
rate=vad_conf['sample_rate'], rate=vad_conf['sample_rate'],
@ -64,13 +67,21 @@ async def websocket_endpoint(websocket: WebSocket):
if message['signal'] == 'start': if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"} resp = {"status": "ok", "signal": "server_ready"}
# do something at begining here # do something at begining here
# create the instance to process the audio
connection_handler = PaddleASRConnectionHanddler(asr_engine)
await websocket.send_json(resp) await websocket.send_json(resp)
elif message['signal'] == 'end': elif message['signal'] == 'end':
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
# reset single engine for an new connection # reset single engine for an new connection
asr_engine.reset() connection_handler.decode(is_finished=True)
resp = {"status": "ok", "signal": "finished"} connection_handler.rescoring()
asr_results = connection_handler.get_result()
connection_handler.reset()
resp = {
"status": "ok",
"signal": "finished",
'asr_results': asr_results
}
await websocket.send_json(resp) await websocket.send_json(resp)
break break
else: else:
@ -79,21 +90,11 @@ async def websocket_endpoint(websocket: WebSocket):
elif "bytes" in message: elif "bytes" in message:
message = message["bytes"] message = message["bytes"]
engine_pool = get_engine_pool() connection_handler.extract_feat(message)
asr_engine = engine_pool['asr'] connection_handler.decode(is_finished=False)
asr_results = "" asr_results = connection_handler.get_result()
frames = chunk_buffer.frame_generator(message)
for frame in frames:
samples = np.frombuffer(frame.bytes, dtype=np.int16)
sample_rate = asr_engine.config.sample_rate
x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
sample_rate)
asr_engine.run(x_chunk, x_chunk_lens)
asr_results = asr_engine.postprocess()
asr_results = asr_engine.postprocess()
resp = {'asr_results': asr_results} resp = {'asr_results': asr_results}
await websocket.send_json(resp) await websocket.send_json(resp)
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass
