update the attention_rescoring method, test=doc

pull/1704/head
xiongxinlei 2 years ago
parent 0c5dbbee5b
commit 97d31f9aac

@@ -71,15 +71,9 @@ asr_online:
         summary: True # False -> do not show predictor config
     chunk_buffer_conf:
-        frame_duration_ms: 85
-        shift_ms: 40
+        window_n: 7 # frame
+        shift_n: 4 # frame
+        window_ms: 25 # ms
+        shift_ms: 10 # ms
         sample_rate: 16000
         sample_width: 2
-    # vad_conf:
-    #     aggressiveness: 2
-    #     sample_rate: 16000
-    #     frame_duration_ms: 20
-    #     sample_width: 2
-    #     padding_ms: 200
-    #     padding_ratio: 0.9
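
The new chunk_buffer_conf expresses the buffer window in frames plus per-frame timing rather than one flat duration: 7 frames of 25 ms stepped by 10 ms span 25 + 6 * 10 = 85 ms, and a shift of 4 frames covers 40 ms, which matches the removed frame_duration_ms: 85 / shift_ms: 40 pair. A minimal sketch of that arithmetic (variable names here are illustrative, not ChunkBuffer internals):

    # how the window/shift settings above translate into PCM byte counts
    sample_rate = 16000   # Hz
    sample_width = 2      # bytes per sample, 16-bit PCM
    window_n, shift_n = 7, 4       # frames per window / frames per shift
    window_ms, shift_ms = 25, 10   # single-frame window / frame shift

    window_duration_ms = window_ms + (window_n - 1) * shift_ms   # 85 ms
    shift_duration_ms = shift_n * shift_ms                       # 40 ms

    window_bytes = sample_rate * window_duration_ms // 1000 * sample_width
    shift_bytes = sample_rate * shift_duration_ms // 1000 * sample_width
    print(window_bytes, shift_bytes)  # 2720 1280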

@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from paddlespeech.s2t.utils.utility import log_add
 from typing import Optional
-from collections import defaultdict
+
 import numpy as np
 import paddle
 from numpy import float32
@@ -22,19 +21,18 @@ from yacs.config import CfgNode
 from paddlespeech.cli.asr.infer import ASRExecutor
 from paddlespeech.cli.asr.infer import model_alias
-from paddlespeech.cli.asr.infer import pretrained_models
 from paddlespeech.cli.log import logger
 from paddlespeech.cli.utils import download_and_decompress
 from paddlespeech.cli.utils import MODEL_HOME
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.frontend.speech import SpeechSegment
 from paddlespeech.s2t.modules.ctc import CTCDecoder
-from paddlespeech.s2t.modules.mask import mask_finished_preds
-from paddlespeech.s2t.modules.mask import mask_finished_scores
-from paddlespeech.s2t.modules.mask import subsequent_mask
 from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
+from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
+from paddlespeech.s2t.utils.tensor_utils import pad_sequence
 from paddlespeech.s2t.utils.utility import UpdateConfig
+from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch
 from paddlespeech.server.engine.base_engine import BaseEngine
 from paddlespeech.server.utils.audio_process import pcm2float
 from paddlespeech.server.utils.paddle_predictor import init_predictor
@@ -62,9 +60,9 @@ pretrained_models = {
     },
     "conformer2online_aishell-zh-16k": {
         'url':
-        'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz',
+        'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.1.model.tar.gz',
         'md5':
-        '7989b3248c898070904cf042fd656003',
+        'b450d5dfaea0ac227c595ce58d18b637',
         'cfg_path':
         'model.yaml',
         'ckpt_path':
@@ -123,9 +121,9 @@ class ASRServerExecutor(ASRExecutor):
             logger.info(f"Load the pretrained model, tag = {tag}")
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
             self.res_path = res_path
-            self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
-            # self.cfg_path = os.path.join(res_path,
-            #                              pretrained_models[tag]['cfg_path'])
+            # self.cfg_path = "/home/users/xiongxinlei/task/paddlespeech-develop/PaddleSpeech/examples/aishell/asr1/model.yaml"
+            self.cfg_path = os.path.join(res_path,
+                                         pretrained_models[tag]['cfg_path'])
             self.am_model = os.path.join(res_path,
                                          pretrained_models[tag]['model'])
@@ -177,6 +175,18 @@ class ASRServerExecutor(ASRExecutor):
             # update the decoding method
             if decode_method:
                 self.config.decode.decoding_method = decode_method
+
+            # we only support the ctc_prefix_beam_search and attention_rescoring
+            # decoding methods; generally we set decoding_method to attention_rescoring
+            if self.config.decode.decoding_method not in [
+                    "ctc_prefix_beam_search", "attention_rescoring"
+            ]:
+                logger.info("we set the decoding_method to attention_rescoring")
+                self.config.decode.decoding_method = "attention_rescoring"
+            assert self.config.decode.decoding_method in [
+                "ctc_prefix_beam_search", "attention_rescoring"
+            ], f"we only support the ctc_prefix_beam_search and attention_rescoring decoding methods, current decoding method is {self.config.decode.decoding_method}"
         else:
             raise Exception("wrong type")
         if "deepspeech2online" in model_type or "deepspeech2offline" in model_type:
@@ -232,7 +242,7 @@ class ASRServerExecutor(ASRExecutor):
         logger.info("create the transformer like model success")

         # update the ctc decoding
-        self.searcher = None
+        self.searcher = CTCPrefixBeamSearch(self.config.decode)
         self.transformer_decode_reset()

     def reset_decoder_and_chunk(self):
@@ -320,7 +330,16 @@ class ASRServerExecutor(ASRExecutor):
     def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens):
         logger.info("start to decode with advanced_decoding method")
         encoder_out, encoder_mask = self.decode_forward(xs)
-        self.ctc_prefix_beam_search(xs, encoder_out, encoder_mask)
+        ctc_probs = self.model.ctc.log_softmax(
+            encoder_out)  # (1, maxlen, vocab_size)
+        ctc_probs = ctc_probs.squeeze(0)
+        self.searcher.search(xs, ctc_probs, xs.place)
+
+        # update the one best result
+        self.hyps = self.searcher.get_one_best_hyps()
+
+        # now we support ctc_prefix_beam_search and attention_rescoring
+        if "attention_rescoring" in self.config.decode.decoding_method:
+            self.rescoring(encoder_out, xs.place)

     def decode_forward(self, xs):
         logger.info("get the model out from the feat")
@@ -338,7 +357,6 @@ class ASRServerExecutor(ASRExecutor):
         num_frames = xs.shape[1]
         required_cache_size = decoding_chunk_size * num_decoding_left_chunks
-        logger.info("start to do model forward")

         outputs = []
@@ -359,85 +377,74 @@ class ASRServerExecutor(ASRExecutor):
             masks = masks.unsqueeze(1)
             return ys, masks

+    def rescoring(self, encoder_out, device):
+        logger.info("start to rescoring the hyps")
+        beam_size = self.config.decode.beam_size
+        hyps = self.searcher.get_hyps()
+        assert len(hyps) == beam_size
+
+        hyp_list = []
+        for hyp in hyps:
+            hyp_content = hyp[0]
+            # prevent the hyp from being empty
+            if len(hyp_content) == 0:
+                hyp_content = (self.model.ctc.blank_id, )
+            hyp_content = paddle.to_tensor(
+                hyp_content, place=device, dtype=paddle.long)
+            hyp_list.append(hyp_content)
+        hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id)
+        hyps_lens = paddle.to_tensor(
+            [len(hyp[0]) for hyp in hyps], place=device,
+            dtype=paddle.long)  # (beam_size,)
+        hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
+                                  self.model.ignore_id)
+        hyps_lens = hyps_lens + 1  # add <sos> at the beginning
+
+        encoder_out = encoder_out.repeat(beam_size, 1, 1)
+        encoder_mask = paddle.ones(
+            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
+        decoder_out, _ = self.model.decoder(
+            encoder_out, encoder_mask, hyps_pad,
+            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
+        # decoder score in ln domain
+        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
+        decoder_out = decoder_out.numpy()
+
+        # only use the decoder score for rescoring
+        best_score = -float('inf')
+        best_index = 0
+        # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size
+        for i, hyp in enumerate(hyps):
+            score = 0.0
+            for j, w in enumerate(hyp[0]):
+                score += decoder_out[i][j][w]
+            # the last decoder output token is `eos`, for the last decoder input token
+            score += decoder_out[i][len(hyp[0])][self.model.eos]
+            # add the ctc score (which is in ln domain)
+            score += hyp[1] * self.config.decode.ctc_weight
+            if score > best_score:
+                best_score = score
+                best_index = i
+
+        # update the one best result
+        self.hyps = [hyps[best_index][0]]
+        return hyps[best_index][0]
+
     def transformer_decode_reset(self):
         self.subsampling_cache = None
         self.elayers_output_cache = None
         self.conformer_cnn_cache = None
+        self.hyps = None
         self.offset = 0
-        self.cur_hyps = None
-        self.hyps = None
+        # decoding reset
+        self.searcher.reset()

-    def ctc_prefix_beam_search(self, xs, encoder_out, encoder_mask, blank_id=0):
-        # decode
-        logger.info("start to ctc prefix search")
-
-        device = xs.place
-        cfg = self.config.decode
-        batch_size = xs.shape[0]
-        beam_size = cfg.beam_size
-        maxlen = encoder_out.shape[1]
-
-        ctc_probs = self.model.ctc.log_softmax(encoder_out)  # (1, maxlen, vocab_size)
-        ctc_probs = ctc_probs.squeeze(0)
-
-        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
-        # blank_ending_score and none_blank_ending_score in ln domain
-        if self.cur_hyps is None:
-            self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
-        # 2. CTC beam search step by step
-        for t in range(0, maxlen):
-            logp = ctc_probs[t]  # (vocab_size,)
-            # key: prefix, value (pb, pnb), default value(-inf, -inf)
-            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
-            # 2.1 First beam prune: select topk best
-            # do token passing process
-            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
-            for s in top_k_index:
-                s = s.item()
-                ps = logp[s].item()
-                for prefix, (pb, pnb) in self.cur_hyps:
-                    last = prefix[-1] if len(prefix) > 0 else None
-                    if s == blank_id:  # blank
-                        n_pb, n_pnb = next_hyps[prefix]
-                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
-                        next_hyps[prefix] = (n_pb, n_pnb)
-                    elif s == last:
-                        # Update *ss -> *s;
-                        n_pb, n_pnb = next_hyps[prefix]
-                        n_pnb = log_add([n_pnb, pnb + ps])
-                        next_hyps[prefix] = (n_pb, n_pnb)
-                        # Update *s-s -> *ss, - is for blank
-                        n_prefix = prefix + (s, )
-                        n_pb, n_pnb = next_hyps[n_prefix]
-                        n_pnb = log_add([n_pnb, pb + ps])
-                        next_hyps[n_prefix] = (n_pb, n_pnb)
-                    else:
-                        n_prefix = prefix + (s, )
-                        n_pb, n_pnb = next_hyps[n_prefix]
-                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
-                        next_hyps[n_prefix] = (n_pb, n_pnb)
-
-            # 2.2 Second beam prune
-            next_hyps = sorted(
-                next_hyps.items(),
-                key=lambda x: log_add(list(x[1])),
-                reverse=True)
-            self.cur_hyps = next_hyps[:beam_size]
-
-        hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
-        self.hyps = [hyps[0][0]]
-        logger.info("ctc prefix search success")
-        return hyps, encoder_out
-
     def update_result(self):
         logger.info("update the final result")
+        hyps = self.hyps
         self.result_transcripts = [
-            self.text_feature.defeaturize(hyp) for hyp in self.hyps
+            self.text_feature.defeaturize(hyp) for hyp in hyps
         ]
-        self.result_tokenids = [hyp for hyp in self.hyps]
+        self.result_tokenids = [hyp for hyp in hyps]

     def extract_feat(self, samples, sample_rate):
         """extract feat
@@ -483,9 +490,9 @@ class ASRServerExecutor(ASRExecutor):
         elif "conformer2online" in self.model_type:
             if sample_rate != self.sample_rate:
-                logger.info(f"audio sample rate {sample_rate} is not match," \
-                            "the model sample_rate is {self.sample_rate}")
-            logger.info(f"ASR Engine use the {self.model_type} to process")
+                logger.info(f"audio sample rate {sample_rate} does not match "
+                            f"the model sample_rate {self.sample_rate}")
+            logger.info(f"ASR Engine uses the {self.model_type} to process")
             logger.info("Create the preprocess instance")
             preprocess_conf = self.config.preprocess_config
             preprocess_args = {"train": False}

paddlespeech/server/engine/asr/online/ctc_search.py (new file)
@@ -0,0 +1,119 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict

from paddlespeech.cli.log import logger
from paddlespeech.s2t.utils.utility import log_add

__all__ = ['CTCPrefixBeamSearch']


class CTCPrefixBeamSearch:
    def __init__(self, config):
        """Implement the ctc prefix beam search

        Args:
            config (yacs.config.CfgNode): the decoding config, which supplies beam_size
        """
        self.config = config
        self.reset()
    def search(self, xs, ctc_probs, device, blank_id=0):
        """ctc prefix beam search method to decode a chunk feature

        Args:
            xs (paddle.Tensor): feature data
            ctc_probs (paddle.Tensor): the ctc log probability of all the tokens
            device (paddle.CPUPlace or paddle.CUDAPlace): the device on which decoding runs
            blank_id (int, optional): the blank id in the vocab. Defaults to 0.

        Returns:
            list: the search result
        """
        # decode
        logger.info("start to ctc prefix search")

        batch_size = xs.shape[0]
        beam_size = self.config.beam_size
        maxlen = ctc_probs.shape[0]
        assert len(ctc_probs.shape) == 2

        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
        # blank_ending_score and none_blank_ending_score in ln domain
        if self.cur_hyps is None:
            self.cur_hyps = [(tuple(), (0.0, -float('inf')))]
        # 2. CTC beam search step by step
        for t in range(0, maxlen):
            logp = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb), default value(-inf, -inf)
            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
            # 2.1 First beam prune: select topk best
            # do token passing process
            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
            for s in top_k_index:
                s = s.item()
                ps = logp[s].item()
                for prefix, (pb, pnb) in self.cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if s == blank_id:  # blank
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                    elif s == last:
                        # Update *ss -> *s;
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pnb = log_add([n_pnb, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                        # Update *s-s -> *ss, - is for blank
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)
                    else:
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)

            # 2.2 Second beam prune
            next_hyps = sorted(
                next_hyps.items(),
                key=lambda x: log_add(list(x[1])),
                reverse=True)
            self.cur_hyps = next_hyps[:beam_size]

        self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps]
        logger.info("ctc prefix search success")
        return self.hyps
    def get_one_best_hyps(self):
        """Return the one best result

        Returns:
            list: the one best result
        """
        return [self.hyps[0][0]]

    def get_hyps(self):
        """Return all the search hypotheses

        Returns:
            list: the search hypotheses, each item is (prefix, score)
        """
        return self.hyps
    def reset(self):
        """Reset the search cache values
        """
        self.cur_hyps = None
        self.hyps = None
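
A minimal usage sketch of the new searcher, assuming a working paddle install and a toy 10-token vocab with blank id 0 (the random tensors stand in for real encoder output and CTC posteriors):

    import paddle
    from yacs.config import CfgNode

    from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch

    decode_conf = CfgNode({'beam_size': 5})
    searcher = CTCPrefixBeamSearch(decode_conf)

    xs = paddle.randn([1, 16, 80])   # (batch, frames, feat_dim), placeholder features
    logits = paddle.randn([16, 10])  # (frames, vocab_size)
    ctc_probs = paddle.nn.functional.log_softmax(logits, axis=-1)

    # cur_hyps persists across calls, so successive chunks continue the same
    # prefix search; call reset() before starting a new utterance
    hyps = searcher.search(xs, ctc_probs, xs.place)
    print(searcher.get_one_best_hyps())  # one-element list with the best id tuple
    searcher.reset()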

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -34,17 +34,17 @@ async def websocket_endpoint(websocket: WebSocket):
     engine_pool = get_engine_pool()
     asr_engine = engine_pool['asr']
     # init buffer
+    # each websocket connection has its own chunk buffer
     chunk_buffer_conf = asr_engine.config.chunk_buffer_conf
     chunk_buffer = ChunkBuffer(
-        window_n=7,
-        shift_n=4,
-        window_ms=20,
-        shift_ms=10,
-        sample_rate=chunk_buffer_conf['sample_rate'],
-        sample_width=chunk_buffer_conf['sample_width'])
+        window_n=chunk_buffer_conf.window_n,
+        shift_n=chunk_buffer_conf.shift_n,
+        window_ms=chunk_buffer_conf.window_ms,
+        shift_ms=chunk_buffer_conf.shift_ms,
+        sample_rate=chunk_buffer_conf.sample_rate,
+        sample_width=chunk_buffer_conf.sample_width)
     # init vad
-    # print(asr_engine.config)
-    # print(type(asr_engine.config))
     vad_conf = asr_engine.config.get('vad_conf', None)
     if vad_conf:
         vad = VADAudio(
@@ -72,7 +72,7 @@ async def websocket_endpoint(websocket: WebSocket):
             engine_pool = get_engine_pool()
             asr_engine = engine_pool['asr']
             # reset single engine for a new connection
-            # asr_engine.reset()
+            asr_engine.reset()
             resp = {"status": "ok", "signal": "finished"}
             await websocket.send_json(resp)
             break
@@ -85,21 +85,16 @@ async def websocket_endpoint(websocket: WebSocket):
             engine_pool = get_engine_pool()
             asr_engine = engine_pool['asr']
             asr_results = ""
-            # frames = chunk_buffer.frame_generator(message)
-            # for frame in frames:
-            #     # get the pcm data from the bytes
-            #     samples = np.frombuffer(frame.bytes, dtype=np.int16)
-            #     sample_rate = asr_engine.config.sample_rate
-            #     x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
-            #                                                   sample_rate)
-            #     asr_engine.run(x_chunk, x_chunk_lens)
-            #     asr_results = asr_engine.postprocess()
-            samples = np.frombuffer(message, dtype=np.int16)
-            sample_rate = asr_engine.config.sample_rate
-            x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
-                                                          sample_rate)
-            asr_engine.run(x_chunk, x_chunk_lens)
-            # asr_results = asr_engine.postprocess()
+            frames = chunk_buffer.frame_generator(message)
+            for frame in frames:
+                # get the pcm data from the bytes
+                samples = np.frombuffer(frame.bytes, dtype=np.int16)
+                sample_rate = asr_engine.config.sample_rate
+                x_chunk, x_chunk_lens = asr_engine.preprocess(samples,
+                                                              sample_rate)
+                asr_engine.run(x_chunk, x_chunk_lens)
+                asr_results = asr_engine.postprocess()
             asr_results = asr_engine.postprocess()

             resp = {'asr_results': asr_results}
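
With the frame generator re-enabled, the endpoint consumes raw int16 little-endian PCM bytes per websocket message and answers each message with a JSON asr_results payload. A minimal client sketch using the websockets package (the host, port, route, and the 16 kHz mono 16-bit source file are assumptions, not taken from this patch):

    import asyncio
    import json
    import wave

    import websockets


    async def stream(path="zh.wav", url="ws://127.0.0.1:8090/ws/asr"):
        with wave.open(path, "rb") as wav:
            pcm = wav.readframes(wav.getnframes())  # int16 little-endian bytes
        chunk = 1280 * 2                            # 80 ms at 16 kHz, 2 bytes/sample
        async with websockets.connect(url) as ws:
            for start in range(0, len(pcm), chunk):
                await ws.send(pcm[start:start + chunk])  # one raw PCM chunk
                resp = json.loads(await ws.recv())       # {'asr_results': ...}
                print(resp.get("asr_results", resp))


    asyncio.run(stream())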
