Merge branch 'develop' into u2pp_export

pull/2425/head
Hui Zhang 3 years ago
commit 2a75405e9a

@@ -21,7 +21,7 @@ engine_list: ['asr_online']
 ################################### ASR #########################################
 ################### speech task: asr; engine_type: online #######################
 asr_online:
-    model_type: 'conformer_online_wenetspeech'
+    model_type: 'conformer_u2pp_online_wenetspeech'
     am_model: # the pdmodel file of am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'

@@ -9,6 +9,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
 [Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz) | Aishell Dataset | Char-based | 491 MB | 2 Conv + 5 LSTM layers | 0.0666 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0) | onnx/inference/python |
 [Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |
 [Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |
+[Conformer U2PP Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.1.model.tar.gz) | WenetSpeech Dataset | Char-based | 476 MB | Encoder:Conformer, Decoder:BiTransformer, Decoding method: Attention rescoring| 0.047198 (aishell test\_-1) 0.059212 (aishell test\_16) |-| 10000 h |- | python |
 [Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |
 [Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.0.1.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0460 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |
 [Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |

@@ -1,11 +1,12 @@
 beam_size: 10
+decode_batch_size: 128
+error_rate_type: cer
 decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
 ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
+reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
 decoding_chunk_size: 16 # decoding chunk size. Defaults to -1.
 # <0: for decoding, use full chunk.
 # >0: for decoding, use fixed chunk size as set.
 # 0: used for training, it's prohibited here.
 num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
 simulate_streaming: True # simulate streaming inference. Defaults to False.
-decode_batch_size: 128
-error_rate_type: cer

@@ -1,11 +1,12 @@
+decode_batch_size: 128
+error_rate_type: cer
+decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
 beam_size: 10
-decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
 ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
+reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
 decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
 # <0: for decoding, use full chunk.
 # >0: for decoding, use fixed chunk size as set.
 # 0: used for training, it's prohibited here.
 num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
 simulate_streaming: False # simulate streaming inference. Defaults to False.
-decode_batch_size: 128
-error_rate_type: cer
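
Both decode configs above gain a reverse_weight entry: the interpolation weight given to the right-to-left (reverse) decoder of the U2++ model during attention rescoring, alongside the existing ctc_weight. A minimal sketch of how the two weights combine, with plain floats and illustrative names rather than the PaddleSpeech API:

    # Sketch: combining forward/reverse attention scores with the CTC score
    # during attention rescoring (all scores in the log domain).
    def rescore(score_l2r: float, score_r2l: float, score_ctc: float,
                reverse_weight: float=0.3, ctc_weight: float=0.5) -> float:
        score = score_l2r * (1 - reverse_weight) + score_r2l * reverse_weight
        return score + score_ctc * ctc_weight

    # Example: the best hypothesis is the one with the highest combined score.
    print(rescore(-12.3, -11.8, -9.1))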

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import io
 import os
 import sys
 import time
@@ -51,7 +52,7 @@ class ASRExecutor(BaseExecutor):
         self.parser.add_argument(
             '--model',
             type=str,
-            default='conformer_wenetspeech',
+            default='conformer_u2pp_wenetspeech',
             choices=[
                 tag[:tag.index('-')]
                 for tag in self.task_resource.pretrained_models.keys()
@@ -229,6 +230,8 @@ class ASRExecutor(BaseExecutor):
         audio_file = input
         if isinstance(audio_file, (str, os.PathLike)):
             logger.debug("Preprocess audio_file:" + audio_file)
+        elif isinstance(audio_file, io.BytesIO):
+            audio_file.seek(0)

         # Get the object for feature extraction
         if "deepspeech2" in model_type or "conformer" in model_type or "transformer" in model_type:
@@ -352,6 +355,8 @@ class ASRExecutor(BaseExecutor):
             if not os.path.isfile(audio_file):
                 logger.error("Please input the right audio file path")
                 return False
+        elif isinstance(audio_file, io.BytesIO):
+            audio_file.seek(0)

         logger.debug("checking the audio file format......")
         try:
@@ -465,7 +470,7 @@ class ASRExecutor(BaseExecutor):
     @stats_wrapper
     def __call__(self,
                  audio_file: os.PathLike,
-                 model: str='conformer_wenetspeech',
+                 model: str='conformer_u2pp_wenetspeech',
                  lang: str='zh',
                  sample_rate: int=16000,
                  config: os.PathLike=None,
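
With the new io import and the io.BytesIO branches above, the CLI executor also accepts an in-memory audio buffer and rewinds it with seek(0) before the format check and preprocessing. A usage sketch, assuming zh.wav is a local 16 kHz Mandarin recording and the rest of the pipeline accepts file-like input:

    import io

    from paddlespeech.cli.asr.infer import ASRExecutor

    asr = ASRExecutor()

    # Read a wav file into memory and hand the buffer to the executor;
    # the new elif branches rewind it with seek(0) before use.
    with open("zh.wav", "rb") as f:      # placeholder local wav file
        buf = io.BytesIO(f.read())

    # The default model is now conformer_u2pp_wenetspeech, per the diff above.
    print(asr(audio_file=buf, lang="zh", sample_rate=16000))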

@@ -25,6 +25,8 @@ model_alias = {
     "deepspeech2online": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"],
     "conformer": ["paddlespeech.s2t.models.u2:U2Model"],
     "conformer_online": ["paddlespeech.s2t.models.u2:U2Model"],
+    "conformer_u2pp": ["paddlespeech.s2t.models.u2:U2Model"],
+    "conformer_u2pp_online": ["paddlespeech.s2t.models.u2:U2Model"],
     "transformer": ["paddlespeech.s2t.models.u2:U2Model"],
     "wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"],

@@ -68,6 +68,46 @@ asr_dynamic_pretrained_models = {
             '',
         },
     },
+    "conformer_u2pp_wenetspeech-zh-16k": {
+        '1.1': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.1.model.tar.gz',
+            'md5':
+            'eae678c04ed3b3f89672052fdc0c5e10',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10',
+            'model':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
+            'params':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
+            'lm_url':
+            '',
+            'lm_md5':
+            '',
+        },
+    },
+    "conformer_u2pp_online_wenetspeech-zh-16k": {
+        '1.1': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.2.model.tar.gz',
+            'md5':
+            '925d047e9188dea7f421a718230c9ae3',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10',
+            'model':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
+            'params':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
+            'lm_url':
+            '',
+            'lm_md5':
+            '',
+        },
+    },
     "conformer_online_multicn-zh-16k": {
         '1.0': {
             'url':
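
The two new registry keys follow the existing <model>-<lang>-<rate> naming, and the CLI's --model choices are cut from these keys at the first '-' (the tag[:tag.index('-')] expression earlier in this commit). A small sketch of that derivation with the keys added here:

    # Sketch: deriving --model choices from registry tags, mirroring
    # tag[:tag.index('-')] in paddlespeech/cli/asr/infer.py.
    tags = [
        "conformer_u2pp_wenetspeech-zh-16k",
        "conformer_u2pp_online_wenetspeech-zh-16k",
        "conformer_online_multicn-zh-16k",
    ]
    choices = [tag[:tag.index('-')] for tag in tags]
    print(choices)
    # ['conformer_u2pp_wenetspeech', 'conformer_u2pp_online_wenetspeech',
    #  'conformer_online_multicn']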

@@ -62,7 +62,6 @@ class U2Infer():
         params_path = self.args.checkpoint_path + ".pdparams"
         model_dict = paddle.load(params_path)
         self.model.set_state_dict(model_dict)
-        logger.info(f"model_dict: {model_dict.keys()}")

     def run(self):
         check(args.audio_file)
@@ -91,22 +90,22 @@ class U2Infer():
                 ctc_weight=decode_config.ctc_weight,
                 decoding_chunk_size=decode_config.decoding_chunk_size,
                 num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
                 simulate_streaming=decode_config.simulate_streaming,
-                reverse_weight=self.reverse_weight)
+                reverse_weight=decode_config.reverse_weight)
             rsl = result_transcripts[0][0]
             utt = Path(self.audio_file).name
-            logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
+            logger.info(f"hyp: {utt} {rsl}")
             # print(self.model)
             # print(self.model.forward_encoder_chunk)
-            # return rsl
-            logger.info("-------------start export ----------------------")
+            logger.info("-------------start quant ----------------------")
             batch_size = 1
             feat_dim = 80
             model_size = 512
             num_left_chunks = -1
+            reverse_weight = 0.3
             logger.info(
-                f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}"
+                f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}, reverse_weight {reverse_weight}"
             )
             # ######################## self.model.forward_encoder_chunk ############
@@ -146,7 +145,6 @@ class U2Infer():
                 self.model.ctc_activation, input_spec=input_spec)

             ######################### self.model.forward_attention_decoder ########################
-            reverse_weight = 0.3
             input_spec = [
                 # hyps, (B, U)
                 paddle.static.InputSpec(shape=[None, None], dtype='int64'),

@@ -40,7 +40,6 @@ class U2Infer():
         self.preprocess_conf = config.preprocess_config
         self.preprocess_args = {"train": False}
         self.preprocessing = Transformation(self.preprocess_conf)
-        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
         self.text_feature = TextFeaturizer(
             unit_type=config.unit_type,
             vocab=config.vocab_filepath,
@@ -90,7 +89,7 @@ class U2Infer():
                 decoding_chunk_size=decode_config.decoding_chunk_size,
                 num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
                 simulate_streaming=decode_config.simulate_streaming,
-                reverse_weight=self.reverse_weight)
+                reverse_weight=decode_config.reverse_weight)
             rsl = result_transcripts[0][0]
             utt = Path(self.audio_file).name
             logger.info(f"hyp: {utt} {result_transcripts[0][0]}")

@@ -316,7 +316,6 @@ class U2Tester(U2Trainer):
             vocab=self.config.vocab_filepath,
             spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list
-        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)

     def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
@@ -351,8 +350,8 @@ class U2Tester(U2Trainer):
             ctc_weight=decode_config.ctc_weight,
             decoding_chunk_size=decode_config.decoding_chunk_size,
             num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
             simulate_streaming=decode_config.simulate_streaming,
-            reverse_weight=self.reverse_weight)
+            reverse_weight=decode_config.reverse_weight)
         decode_time = time.time() - start_time

         for utt, target, result, rec_tids in zip(

@@ -131,7 +131,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         if self.ctc_weight != 1.0:
             start = time.time()
             loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
-                                                    text, text_lengths)
+                                                    text, text_lengths, self.reverse_weight)
             decoder_time = time.time() - start
             #logger.debug(f"decoder time: {decoder_time}")
@@ -157,7 +157,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                        encoder_out: paddle.Tensor,
                        encoder_mask: paddle.Tensor,
                        ys_pad: paddle.Tensor,
-                       ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
+                       ys_pad_lens: paddle.Tensor,
+                       reverse_weight: float) -> Tuple[paddle.Tensor, float]:
         """Calc attention loss.
         Args:
@@ -165,6 +166,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
             encoder_mask (paddle.Tensor): [B, 1, Tmax]
             ys_pad (paddle.Tensor): [B, Umax]
             ys_pad_lens (paddle.Tensor): [B]
+            reverse_weight (float): reverse decoder weight.
         Returns:
             Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
@@ -179,15 +181,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
         # 1. Forward decoder
         decoder_out, r_decoder_out, _ = self.decoder(
             encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad,
-            self.reverse_weight)
+            reverse_weight)

         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
         r_loss_att = paddle.to_tensor(0.0)
-        if self.reverse_weight > 0.0:
+        if reverse_weight > 0.0:
             r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
-        loss_att = loss_att * (1 - self.reverse_weight
-                               ) + r_loss_att * self.reverse_weight
+        loss_att = loss_att * (1 - reverse_weight
+                               ) + r_loss_att * reverse_weight
         acc_att = th_accuracy(
             decoder_out.view(-1, self.vocab_size),
             ys_out_pad,
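
Passing reverse_weight into _calc_att_loss makes the attention loss an interpolation of the left-to-right and right-to-left decoder losses, as the lines above show. A standalone sketch of just that interpolation, with plain floats standing in for paddle tensors and the model's criterion_att:

    # Sketch of the bidirectional attention-loss interpolation used above.
    def combine_att_loss(loss_l2r: float, loss_r2l: float,
                         reverse_weight: float) -> float:
        if reverse_weight > 0.0:
            return loss_l2r * (1 - reverse_weight) + loss_r2l * reverse_weight
        return loss_l2r

    print(combine_att_loss(2.31, 2.58, reverse_weight=0.3))
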
@@ -501,16 +503,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
             num_decoding_left_chunks, simulate_streaming)
         return hyps[0][0]

-    def attention_rescoring(
-            self,
-            speech: paddle.Tensor,
-            speech_lengths: paddle.Tensor,
-            beam_size: int,
-            decoding_chunk_size: int=-1,
-            num_decoding_left_chunks: int=-1,
-            ctc_weight: float=0.0,
-            simulate_streaming: bool=False,
-            reverse_weight: float=0.0, ) -> List[int]:
+    def attention_rescoring(self,
+                            speech: paddle.Tensor,
+                            speech_lengths: paddle.Tensor,
+                            beam_size: int,
+                            decoding_chunk_size: int=-1,
+                            num_decoding_left_chunks: int=-1,
+                            ctc_weight: float=0.0,
+                            simulate_streaming: bool=False,
+                            reverse_weight: float=0.0) -> List[int]:
         """ Apply attention rescoring decoding, CTC prefix beam search
             is applied first to get nbest, then we resoring the nbest on
             attention decoder with corresponding encoder out
@@ -525,6 +526,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 0: used for training, it's prohibited here
             simulate_streaming (bool): whether do encoder forward in a
                 streaming fashion
+            reverse_weight (float): reverse deocder weight.
         Returns:
             List[int]: Attention rescoring result
         """
@@ -554,14 +556,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 hyp_content, place=device, dtype=paddle.long)
             hyp_list.append(hyp_content)
         hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
-        ori_hyps_pad = hyps_pad
         hyps_lens = paddle.to_tensor(
             [len(hyp[0]) for hyp in hyps], place=device,
             dtype=paddle.long)  # (beam_size,)
         hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
-        hyps_lens = hyps_lens + 1  # Add <sos> at begining
         logger.debug(
             f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
+        hyps_lens = hyps_lens + 1  # Add <sos> at begining

         # ctc score in ln domain
         # (beam_size, max_hyps_len, vocab_size)
@@ -598,8 +599,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                     f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
                 )
-            score = score * (1 - reverse_weight) + r_score * reverse_weight
+            score = score * (1 - reverse_weight
+                             ) + r_score * reverse_weight
             # add ctc score (which in ln domain)
             score += hyp[1] * ctc_weight
             if score > best_score:
@@ -769,6 +770,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
             num_decoding_left_chunks (int, optional):
                 number of left chunks for decoding. Defaults to -1.
             simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.
+            reverse_weight (float, optional): reverse decoder weight, used by `attention_rescoring`.
         Raises:
             ValueError: when not support decoding_method.

@@ -30,7 +30,7 @@ asr_online:
     decode_method:
     num_decoding_left_chunks: -1
     force_yes: True
-    device: # cpu or gpu:id
+    device: cpu # cpu or gpu:id
     continuous_decoding: True # enable continue decoding when endpoint detected

     am_predictor_conf:

@@ -22,6 +22,7 @@ from numpy import float32
 from yacs.config import CfgNode

 from paddlespeech.audio.transform.transformation import Transformation
+from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
 from paddlespeech.cli.asr.infer import ASRExecutor
 from paddlespeech.cli.log import logger
 from paddlespeech.resource import CommonTaskResource
@@ -602,6 +603,7 @@ class PaddleASRConnectionHanddler:
         hyps_pad = pad_sequence(
             hyp_list, batch_first=True, padding_value=self.model.ignore_id)
+        ori_hyps_pad = hyps_pad
         hyps_lens = paddle.to_tensor(
             [len(hyp[0]) for hyp in hyps], place=self.device,
             dtype=paddle.long)  # (beam_size,)
@@ -609,16 +611,15 @@ class PaddleASRConnectionHanddler:
             self.model.ignore_id)
         hyps_lens = hyps_lens + 1  # Add <sos> at begining

-        encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
-        encoder_mask = paddle.ones(
-            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-        decoder_out, _, _ = self.model.decoder(
-            encoder_out, encoder_mask, hyps_pad,
-            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
         # ctc score in ln domain
-        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
+        # (beam_size, max_hyps_len, vocab_size)
+        decoder_out, r_decoder_out = self.model.forward_attention_decoder(
+            hyps_pad, hyps_lens, self.encoder_out, self.model.reverse_weight)
         decoder_out = decoder_out.numpy()
+        # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
+        # conventional transformer decoder.
+        r_decoder_out = r_decoder_out.numpy()

         # Only use decoder score for rescoring
         best_score = -float('inf')
@@ -631,6 +632,13 @@ class PaddleASRConnectionHanddler:
             # last decoder output token is `eos`, for laste decoder input token.
             score += decoder_out[i][len(hyp[0])][self.model.eos]
+            if self.model.reverse_weight > 0:
+                r_score = 0.0
+                for j, w in enumerate(hyp[0]):
+                    r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
+                r_score += r_decoder_out[i][len(hyp[0])][self.model.eos]
+                score = score * (1 - self.model.reverse_weight
+                                 ) + r_score * self.model.reverse_weight
             # add ctc score (which in ln domain)
             score += hyp[1] * self.ctc_decode_config.ctc_weight
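
The added loop reads the right-to-left decoder scores by mirroring the hypothesis: token j is looked up at position len(hyp) - j - 1, and an eos term is added at position len(hyp). A toy check of that indexing, with numpy standing in for the decoder output (values are made up for illustration):

    import numpy as np

    hyp = [4, 2, 3]                # token ids of one hypothesis
    eos = 0
    # toy (len(hyp) + 1) x vocab log-prob matrix from the reverse decoder
    r_decoder_out = np.log(np.full((len(hyp) + 1, 5), 0.2))

    r_score = 0.0
    for j, w in enumerate(hyp):
        # token j is scored at the mirrored position of the reverse decoder
        r_score += r_decoder_out[len(hyp) - j - 1][w]
    r_score += r_decoder_out[len(hyp)][eos]
    print(r_score)                 # 4 * log(0.2)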

@@ -107,11 +107,14 @@ class PaddleTextConnectionHandler:
             assert len(tokens) == len(labels)
             text = ''
+            is_fast_model = 'fast' in self.text_engine.config.model_type
             for t, l in zip(tokens, labels):
                 text += t
                 if l != 0:  # Non punc.
-                    text += self._punc_list[l]
+                    if is_fast_model:
+                        text += self._punc_list[l - 1]
+                    else:
+                        text += self._punc_list[l]
             return text
         else:
             raise NotImplementedError
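
The branch above handles an index offset: label 0 still means no punctuation, but for the 'fast' punctuation model the remaining labels appear to be 1-based into the punctuation list, hence the l - 1. A toy illustration with an invented punctuation inventory:

    # Toy illustration of the label-to-punctuation mapping handled above.
    punc_list = ['，', '。', '？']       # illustrative inventory
    tokens = ['你', '好', '吗']
    labels = [0, 0, 3]                   # 'fast' model: 1-based punc indices

    text = ''
    for t, l in zip(tokens, labels):
        text += t
        if l != 0:                       # non-zero label -> punctuation mark
            text += punc_list[l - 1]     # shift by one for the fast model
    print(text)                          # 你好吗？
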
@@ -160,14 +163,23 @@ class TextEngine(BaseEngine):
             return False

         self.executor = TextServerExecutor()
-        self.executor._init_from_path(
-            task=config.task,
-            model_type=config.model_type,
-            lang=config.lang,
-            cfg_path=config.cfg_path,
-            ckpt_path=config.ckpt_path,
-            vocab_file=config.vocab_file)
+        if 'fast' in config.model_type:
+            self.executor._init_from_path_new(
+                task=config.task,
+                model_type=config.model_type,
+                lang=config.lang,
+                cfg_path=config.cfg_path,
+                ckpt_path=config.ckpt_path,
+                vocab_file=config.vocab_file)
+        else:
+            self.executor._init_from_path(
+                task=config.task,
+                model_type=config.model_type,
+                lang=config.lang,
+                cfg_path=config.cfg_path,
+                ckpt_path=config.ckpt_path,
+                vocab_file=config.vocab_file)
+        logger.info("Using model: %s." % (config.model_type))

         logger.info("Initialize Text server engine successfully on device: %s."
                     % (self.device))
         return True
