diff --git a/demos/streaming_asr_server/conf/application.yaml b/demos/streaming_asr_server/conf/application.yaml
index a89d312ab..d446e13b6 100644
--- a/demos/streaming_asr_server/conf/application.yaml
+++ b/demos/streaming_asr_server/conf/application.yaml
@@ -21,7 +21,7 @@ engine_list: ['asr_online']
 ################################### ASR #########################################
 ################### speech task: asr; engine_type: online #######################
 asr_online:
-    model_type: 'conformer_online_wenetspeech'
+    model_type: 'conformer_u2pp_online_wenetspeech'
     am_model: # the pdmodel file of am static model [optional]
     am_params: # the pdiparams file of am static model [optional]
     lang: 'zh'
diff --git a/docs/source/released_model.md b/docs/source/released_model.md
index d6691812e..bdac2c5bb 100644
--- a/docs/source/released_model.md
+++ b/docs/source/released_model.md
@@ -9,6 +9,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
 [Ds2 Online Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz) | Aishell Dataset | Char-based | 491 MB | 2 Conv + 5 LSTM layers | 0.0666 |-| 151 h | [D2 Online Aishell ASR0](../../examples/aishell/asr0) | onnx/inference/python |
 [Ds2 Offline Aishell ASR0 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_offline_aishell_ckpt_1.0.1.model.tar.gz)| Aishell Dataset | Char-based | 1.4 GB | 2 Conv + 5 bidirectional LSTM layers| 0.0554 |-| 151 h | [Ds2 Offline Aishell ASR0](../../examples/aishell/asr0) | inference/python |
 [Conformer Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz) | WenetSpeech Dataset | Char-based | 457 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.11 (test\_net) 0.1879 (test\_meeting) |-| 10000 h |- | python |
+[Conformer U2PP Online Wenetspeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.1.model.tar.gz) | WenetSpeech Dataset | Char-based | 476 MB | Encoder:Conformer, Decoder:BiTransformer, Decoding method: Attention rescoring| 0.047198 (aishell test\_-1) 0.059212 (aishell test\_16) |-| 10000 h |- | python |
 [Conformer Online Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring| 0.0544 |-| 151 h | [Conformer Online Aishell ASR1](../../examples/aishell/asr1) | python |
 [Conformer Offline Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_1.0.1.model.tar.gz) | Aishell Dataset | Char-based | 189 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0460 |-| 151 h | [Conformer Offline Aishell ASR1](../../examples/aishell/asr1) | python |
 [Transformer Aishell ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_transformer_aishell_ckpt_0.1.1.model.tar.gz) | Aishell Dataset | Char-based | 128 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring | 0.0523 || 151 h | [Transformer Aishell ASR1](../../examples/aishell/asr1) | python |
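Note on the new released-model row: the U2++ WenetSpeech checkpoint ships as a tarball whose layout the registry entries later in this patch point into (model.yaml, exp/chunk_conformer_u2pp/checkpoints/avg_10). A minimal, standard-library-only sketch for pulling it down and checking that layout by hand; the local file and directory names are placeholders, not part of the patch:

```python
# Hedged sketch: download and unpack the checkpoint listed in released_model.md
# so the cfg_path/ckpt_path values added to pretrained_models.py below can be
# verified manually. Local names here are illustrative.
import tarfile
import urllib.request

URL = ("https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/"
       "asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.1.model.tar.gz")
fname, _ = urllib.request.urlretrieve(URL, "u2pp_wenetspeech_1.1.1.tar.gz")
with tarfile.open(fname, "r:gz") as tar:
    # Expect model.yaml and exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams inside.
    tar.extractall("./u2pp_wenetspeech")
```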
diff --git a/examples/wenetspeech/asr1/conf/tuning/chunk_decode.yaml b/examples/wenetspeech/asr1/conf/tuning/chunk_decode.yaml
index 7e8afb7a8..6945ed6eb 100644
--- a/examples/wenetspeech/asr1/conf/tuning/chunk_decode.yaml
+++ b/examples/wenetspeech/asr1/conf/tuning/chunk_decode.yaml
@@ -1,11 +1,12 @@
 beam_size: 10
-decode_batch_size: 128
-error_rate_type: cer
 decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
 ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
+reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
 decoding_chunk_size: 16 # decoding chunk size. Defaults to -1.
                         # <0: for decoding, use full chunk.
                         # >0: for decoding, use fixed chunk size as set.
                         # 0: used for training, it's prohibited here.
 num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
 simulate_streaming: True # simulate streaming inference. Defaults to False.
+decode_batch_size: 128
+error_rate_type: cer
diff --git a/examples/wenetspeech/asr1/conf/tuning/decode.yaml b/examples/wenetspeech/asr1/conf/tuning/decode.yaml
index 6924bfa63..4015e9836 100644
--- a/examples/wenetspeech/asr1/conf/tuning/decode.yaml
+++ b/examples/wenetspeech/asr1/conf/tuning/decode.yaml
@@ -1,11 +1,12 @@
-decode_batch_size: 128
-error_rate_type: cer
-decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
 beam_size: 10
+decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
 ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
+reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
 decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
                         # <0: for decoding, use full chunk.
                         # >0: for decoding, use fixed chunk size as set.
                         # 0: used for training, it's prohibited here.
 num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
-simulate_streaming: False # simulate streaming inference. Defaults to False.
\ No newline at end of file
+simulate_streaming: False # simulate streaming inference. Defaults to False.
+decode_batch_size: 128
+error_rate_type: cer
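Both tuning YAMLs now define `reverse_weight`, and the exp entry points later in this patch read it straight from the decode config rather than from `model_conf`. A hedged sketch of loading such a config with yacs; this mirrors how the repo's test scripts appear to load decode configs, and the path and assertion are illustrative only:

```python
# Hedged sketch: load a decode tuning YAML and confirm the new key is present
# before decoding, since the scripts below read it unconditionally.
from yacs.config import CfgNode

decode_config = CfgNode(new_allowed=True)
decode_config.merge_from_file("conf/tuning/chunk_decode.yaml")  # path is illustrative

assert hasattr(decode_config, "reverse_weight"), "add reverse_weight to the YAML"
print(decode_config.reverse_weight)  # 0.3 in the configs above
```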
diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py
index 7296776f9..437f64631 100644
--- a/paddlespeech/cli/asr/infer.py
+++ b/paddlespeech/cli/asr/infer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import io
 import os
 import sys
 import time
@@ -51,7 +52,7 @@ class ASRExecutor(BaseExecutor):
         self.parser.add_argument(
             '--model',
             type=str,
-            default='conformer_wenetspeech',
+            default='conformer_u2pp_wenetspeech',
             choices=[
                 tag[:tag.index('-')]
                 for tag in self.task_resource.pretrained_models.keys()
@@ -229,6 +230,8 @@ class ASRExecutor(BaseExecutor):
         audio_file = input
         if isinstance(audio_file, (str, os.PathLike)):
             logger.debug("Preprocess audio_file:" + audio_file)
+        elif isinstance(audio_file, io.BytesIO):
+            audio_file.seek(0)
 
         # Get the object for feature extraction
         if "deepspeech2" in model_type or "conformer" in model_type or "transformer" in model_type:
@@ -352,6 +355,8 @@ class ASRExecutor(BaseExecutor):
             if not os.path.isfile(audio_file):
                 logger.error("Please input the right audio file path")
                 return False
+        elif isinstance(audio_file, io.BytesIO):
+            audio_file.seek(0)
 
         logger.debug("checking the audio file format......")
         try:
@@ -465,7 +470,7 @@ class ASRExecutor(BaseExecutor):
     @stats_wrapper
     def __call__(self,
                  audio_file: os.PathLike,
-                 model: str='conformer_wenetspeech',
+                 model: str='conformer_u2pp_wenetspeech',
                  lang: str='zh',
                  sample_rate: int=16000,
                  config: os.PathLike=None,
diff --git a/paddlespeech/resource/model_alias.py b/paddlespeech/resource/model_alias.py
index 85187a8d1..f5ec655b7 100644
--- a/paddlespeech/resource/model_alias.py
+++ b/paddlespeech/resource/model_alias.py
@@ -25,6 +25,8 @@ model_alias = {
     "deepspeech2online": ["paddlespeech.s2t.models.ds2:DeepSpeech2Model"],
     "conformer": ["paddlespeech.s2t.models.u2:U2Model"],
     "conformer_online": ["paddlespeech.s2t.models.u2:U2Model"],
+    "conformer_u2pp": ["paddlespeech.s2t.models.u2:U2Model"],
+    "conformer_u2pp_online": ["paddlespeech.s2t.models.u2:U2Model"],
     "transformer": ["paddlespeech.s2t.models.u2:U2Model"],
     "wenetspeech": ["paddlespeech.s2t.models.u2:U2Model"],
diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py
index b6ab7f01c..0103651bc 100644
--- a/paddlespeech/resource/pretrained_models.py
+++ b/paddlespeech/resource/pretrained_models.py
@@ -68,6 +68,46 @@ asr_dynamic_pretrained_models = {
             '',
         },
     },
+    "conformer_u2pp_wenetspeech-zh-16k": {
+        '1.1': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.1.model.tar.gz',
+            'md5':
+            'eae678c04ed3b3f89672052fdc0c5e10',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10',
+            'model':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
+            'params':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
+            'lm_url':
+            '',
+            'lm_md5':
+            '',
+        },
+    },
+    "conformer_u2pp_online_wenetspeech-zh-16k": {
+        '1.1': {
+            'url':
+            'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.1.2.model.tar.gz',
+            'md5':
+            '925d047e9188dea7f421a718230c9ae3',
+            'cfg_path':
+            'model.yaml',
+            'ckpt_path':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10',
+            'model':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
+            'params':
+            'exp/chunk_conformer_u2pp/checkpoints/avg_10.pdparams',
+            'lm_url':
+            '',
+            'lm_md5':
+            '',
+        },
+    },
     "conformer_online_multicn-zh-16k": {
         '1.0': {
             'url':
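With the registry entries above plus the `io.BytesIO` handling in `infer.py`, the CLI executor should accept both a path and an in-memory buffer and default to the new U2++ tag. A hedged usage sketch; the wav filename is a placeholder, and exact behaviour depends on parts of the executor this diff does not show:

```python
# Hedged sketch exercising the two CLI-facing changes above.
import io

from paddlespeech.cli.asr.infer import ASRExecutor

asr = ASRExecutor()

# 1) Path input; the model now defaults to the 'conformer_u2pp_wenetspeech' tag.
print(asr(audio_file="zh.wav", lang="zh", sample_rate=16000))

# 2) In-memory input; the executor now rewinds the buffer before reading it.
with open("zh.wav", "rb") as f:
    buf = io.BytesIO(f.read())
print(asr(audio_file=buf, lang="zh", sample_rate=16000))
```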
".pdparams" model_dict = paddle.load(params_path) self.model.set_state_dict(model_dict) - logger.info(f"model_dict: {model_dict.keys()}") def run(self): check(args.audio_file) @@ -91,22 +90,22 @@ class U2Infer(): ctc_weight=decode_config.ctc_weight, decoding_chunk_size=decode_config.decoding_chunk_size, num_decoding_left_chunks=decode_config.num_decoding_left_chunks, - simulate_streaming=decode_config.simulate_streaming, - reverse_weight=self.reverse_weight) + simulate_streaming=decode_config.simulate_streaming + reverse_weight=decode_config.reverse_weight) rsl = result_transcripts[0][0] utt = Path(self.audio_file).name - logger.info(f"hyp: {utt} {result_transcripts[0][0]}") + logger.info(f"hyp: {utt} {rsl}") # print(self.model) # print(self.model.forward_encoder_chunk) - # return rsl - logger.info("-------------start export ----------------------") + logger.info("-------------start quant ----------------------") batch_size = 1 feat_dim = 80 model_size = 512 num_left_chunks = -1 + reverse_weight = 0.3 logger.info( - f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}" + f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}, reverse_weight {reverse_weight}" ) # ######################## self.model.forward_encoder_chunk ############ @@ -146,7 +145,6 @@ class U2Infer(): self.model.ctc_activation, input_spec=input_spec) ######################### self.model.forward_attention_decoder ######################## - reverse_weight = 0.3 input_spec = [ # hyps, (B, U) paddle.static.InputSpec(shape=[None, None], dtype='int64'), diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py index 31890cb19..2e067ab6b 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -40,7 +40,6 @@ class U2Infer(): self.preprocess_conf = config.preprocess_config self.preprocess_args = {"train": False} self.preprocessing = Transformation(self.preprocess_conf) - self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0) self.text_feature = TextFeaturizer( unit_type=config.unit_type, vocab=config.vocab_filepath, @@ -90,7 +89,7 @@ class U2Infer(): decoding_chunk_size=decode_config.decoding_chunk_size, num_decoding_left_chunks=decode_config.num_decoding_left_chunks, simulate_streaming=decode_config.simulate_streaming, - reverse_weight=self.reverse_weight) + reverse_weight=decode_config.reverse_weight) rsl = result_transcripts[0][0] utt = Path(self.audio_file).name logger.info(f"hyp: {utt} {result_transcripts[0][0]}") diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 64b6c8df6..4208d389e 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -316,7 +316,6 @@ class U2Tester(U2Trainer): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) self.vocab_list = self.text_feature.vocab_list - self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0) def id2token(self, texts, texts_len, text_feature): """ ord() id to chr() chr """ @@ -351,8 +350,8 @@ class U2Tester(U2Trainer): ctc_weight=decode_config.ctc_weight, decoding_chunk_size=decode_config.decoding_chunk_size, num_decoding_left_chunks=decode_config.num_decoding_left_chunks, - simulate_streaming=decode_config.simulate_streaming, - reverse_weight=self.reverse_weight) + 
diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py
index 64b6c8df6..4208d389e 100644
--- a/paddlespeech/s2t/exps/u2/model.py
+++ b/paddlespeech/s2t/exps/u2/model.py
@@ -316,7 +316,6 @@ class U2Tester(U2Trainer):
             vocab=self.config.vocab_filepath,
             spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list
-        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
 
     def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
@@ -351,8 +350,8 @@ class U2Tester(U2Trainer):
             ctc_weight=decode_config.ctc_weight,
             decoding_chunk_size=decode_config.decoding_chunk_size,
             num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
-            simulate_streaming=decode_config.simulate_streaming,
-            reverse_weight=self.reverse_weight)
+            simulate_streaming=decode_config.simulate_streaming,
+            reverse_weight=decode_config.reverse_weight)
         decode_time = time.time() - start_time
 
         for utt, target, result, rec_tids in zip(
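The `u2.py` changes below thread `reverse_weight` through `_calc_att_loss` and `attention_rescoring` as an explicit argument instead of reading `self.reverse_weight` inside them. The weighting itself is a simple convex blend of the left-to-right and right-to-left decoder terms. A minimal sketch of that formula, with toy numbers that are not from the patch:

```python
# Hedged sketch of the blend used for both the attention loss and the
# rescoring score: result = (1 - w) * left_to_right + w * right_to_left.
def blend(l2r: float, r2l: float, reverse_weight: float) -> float:
    return l2r * (1.0 - reverse_weight) + r2l * reverse_weight

print(blend(-3.2, -3.5, 0.0))  # 0.0 disables the right-to-left branch
print(blend(-3.2, -3.5, 0.3))  # 0.3 as set in the tuning YAMLs above
```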
diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py
index c25c2186d..5cdcae06f 100644
--- a/paddlespeech/s2t/models/u2/u2.py
+++ b/paddlespeech/s2t/models/u2/u2.py
@@ -131,7 +131,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         if self.ctc_weight != 1.0:
             start = time.time()
             loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
-                                                    text, text_lengths)
+                                                    text, text_lengths, self.reverse_weight)
             decoder_time = time.time() - start
             #logger.debug(f"decoder time: {decoder_time}")
 
@@ -157,7 +157,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                        encoder_out: paddle.Tensor,
                        encoder_mask: paddle.Tensor,
                        ys_pad: paddle.Tensor,
-                       ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
+                       ys_pad_lens: paddle.Tensor,
+                       reverse_weight: float) -> Tuple[paddle.Tensor, float]:
         """Calc attention loss.
 
         Args:
@@ -165,6 +166,7 @@
             encoder_mask (paddle.Tensor): [B, 1, Tmax]
             ys_pad (paddle.Tensor): [B, Umax]
             ys_pad_lens (paddle.Tensor): [B]
+            reverse_weight (float): reverse decoder weight.
 
         Returns:
             Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
@@ -179,15 +181,15 @@
         # 1. Forward decoder
         decoder_out, r_decoder_out, _ = self.decoder(
             encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad,
-            self.reverse_weight)
+            reverse_weight)
 
         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
         r_loss_att = paddle.to_tensor(0.0)
-        if self.reverse_weight > 0.0:
+        if reverse_weight > 0.0:
             r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
-        loss_att = loss_att * (1 - self.reverse_weight
-                               ) + r_loss_att * self.reverse_weight
+        loss_att = loss_att * (1 - reverse_weight
+                               ) + r_loss_att * reverse_weight
         acc_att = th_accuracy(
             decoder_out.view(-1, self.vocab_size),
             ys_out_pad,
@@ -501,16 +503,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
             num_decoding_left_chunks, simulate_streaming)
         return hyps[0][0]
 
-    def attention_rescoring(
-            self,
-            speech: paddle.Tensor,
-            speech_lengths: paddle.Tensor,
-            beam_size: int,
-            decoding_chunk_size: int=-1,
-            num_decoding_left_chunks: int=-1,
-            ctc_weight: float=0.0,
-            simulate_streaming: bool=False,
-            reverse_weight: float=0.0, ) -> List[int]:
+    def attention_rescoring(self,
+                            speech: paddle.Tensor,
+                            speech_lengths: paddle.Tensor,
+                            beam_size: int,
+                            decoding_chunk_size: int=-1,
+                            num_decoding_left_chunks: int=-1,
+                            ctc_weight: float=0.0,
+                            simulate_streaming: bool=False,
+                            reverse_weight: float=0.0) -> List[int]:
         """ Apply attention rescoring decoding, CTC prefix beam search
         is applied first to get nbest, then we resoring the nbest on
         attention decoder with corresponding encoder out
@@ -525,6 +526,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 0: used for training, it's prohibited here
             simulate_streaming (bool): whether do encoder forward in a
                 streaming fashion
+            reverse_weight (float): reverse decoder weight.
         Returns:
             List[int]: Attention rescoring result
         """
@@ -554,14 +556,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 hyp_content, place=device, dtype=paddle.long)
             hyp_list.append(hyp_content)
         hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
-        ori_hyps_pad = hyps_pad
         hyps_lens = paddle.to_tensor(
             [len(hyp[0]) for hyp in hyps], place=device,
             dtype=paddle.long)  # (beam_size,)
         hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
+        hyps_lens = hyps_lens + 1  # Add at begining
         logger.debug(
             f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
-        hyps_lens = hyps_lens + 1  # Add at begining
 
         # ctc score in ln domain
         # (beam_size, max_hyps_len, vocab_size)
@@ -598,8 +599,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                     f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
                 )
 
-            score = score * (1 - reverse_weight) + r_score * reverse_weight
-
+            score = score * (1 - reverse_weight
+                             ) + r_score * reverse_weight
             # add ctc score (which in ln domain)
             score += hyp[1] * ctc_weight
             if score > best_score:
@@ -769,6 +770,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
             num_decoding_left_chunks (int, optional): number of left chunks for decoding. Defaults to -1.
             simulate_streaming (bool, optional): simulate streaming inference.
                 Defaults to False.
+            reverse_weight (float, optional): reverse decoder weight, used by `attention_rescoring`.
 
         Raises:
             ValueError: when not support decoding_method.
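For reference, the per-hypothesis scoring that `attention_rescoring` performs above, and that the server-side connection handler reimplements on numpy arrays further below, can be summarised in a few lines. A hedged, framework-free sketch; log-probabilities are nested lists and the names are illustrative:

```python
# Hedged sketch of attention rescoring with a right-to-left decoder: the
# reverse decoder scores each hypothesis with its tokens read backwards, the
# two attention scores are blended by reverse_weight, and the CTC score is
# added on top.
from typing import List, Sequence


def rescore_hyp(hyp: Sequence[int],
                decoder_out: List[List[float]],
                r_decoder_out: List[List[float]],
                eos: int,
                ctc_score: float,
                ctc_weight: float,
                reverse_weight: float) -> float:
    # Left-to-right attention score, closed by the eos token.
    score = sum(decoder_out[j][w] for j, w in enumerate(hyp))
    score += decoder_out[len(hyp)][eos]
    if reverse_weight > 0:
        # Right-to-left score: token j of hyp sits at reversed position len-1-j.
        r_score = sum(r_decoder_out[len(hyp) - j - 1][w] for j, w in enumerate(hyp))
        r_score += r_decoder_out[len(hyp)][eos]
        score = score * (1 - reverse_weight) + r_score * reverse_weight
    return score + ctc_score * ctc_weight
```

The best hypothesis is then simply the one with the highest blended score over the n-best list, which is exactly what the loops above and below compute.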
diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml
index d72eb2379..d5357c853 100644
--- a/paddlespeech/server/conf/ws_conformer_application.yaml
+++ b/paddlespeech/server/conf/ws_conformer_application.yaml
@@ -30,7 +30,7 @@ asr_online:
     decode_method:
     num_decoding_left_chunks: -1
     force_yes: True
-    device: # cpu or gpu:id
+    device: cpu # cpu or gpu:id
     continuous_decoding: True  # enable continue decoding when endpoint detected
 
     am_predictor_conf:
diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py
index ae0260929..27eda7ef6 100644
--- a/paddlespeech/server/engine/asr/online/python/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py
@@ -22,6 +22,7 @@ from numpy import float32
 from yacs.config import CfgNode
 
 from paddlespeech.audio.transform.transformation import Transformation
+from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
 from paddlespeech.cli.asr.infer import ASRExecutor
 from paddlespeech.cli.log import logger
 from paddlespeech.resource import CommonTaskResource
@@ -602,6 +603,7 @@ class PaddleASRConnectionHanddler:
 
         hyps_pad = pad_sequence(
             hyp_list, batch_first=True, padding_value=self.model.ignore_id)
+        ori_hyps_pad = hyps_pad
         hyps_lens = paddle.to_tensor(
             [len(hyp[0]) for hyp in hyps], place=self.device,
             dtype=paddle.long)  # (beam_size,)
@@ -609,16 +611,15 @@ class PaddleASRConnectionHanddler:
             self.model.ignore_id)
         hyps_lens = hyps_lens + 1  # Add at begining
 
-        encoder_out = self.encoder_out.repeat(beam_size, 1, 1)
-        encoder_mask = paddle.ones(
-            (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-
-        decoder_out, _, _ = self.model.decoder(
-            encoder_out, encoder_mask, hyps_pad,
-            hyps_lens)  # (beam_size, max_hyps_len, vocab_size)
         # ctc score in ln domain
-        decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
+        # (beam_size, max_hyps_len, vocab_size)
+        decoder_out, r_decoder_out = self.model.forward_attention_decoder(
+            hyps_pad, hyps_lens, self.encoder_out, self.model.reverse_weight)
+        decoder_out = decoder_out.numpy()
+        # r_decoder_out will be 0.0 if reverse_weight is 0.0 or the decoder is a
+        # conventional transformer decoder.
+        r_decoder_out = r_decoder_out.numpy()
 
         # Only use decoder score for rescoring
         best_score = -float('inf')
@@ -631,6 +632,13 @@ class PaddleASRConnectionHanddler:
             # last decoder output token is `eos`, for laste decoder input token.
             score += decoder_out[i][len(hyp[0])][self.model.eos]
+            if self.model.reverse_weight > 0:
+                r_score = 0.0
+                for j, w in enumerate(hyp[0]):
+                    r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
+                r_score += r_decoder_out[i][len(hyp[0])][self.model.eos]
+                score = score * (1 - self.model.reverse_weight
+                                 ) + r_score * self.model.reverse_weight
             # add ctc score (which in ln domain)
             score += hyp[1] * self.ctc_decode_config.ctc_weight
diff --git a/paddlespeech/server/engine/text/python/text_engine.py b/paddlespeech/server/engine/text/python/text_engine.py
index 6167e7784..cc72c0543 100644
--- a/paddlespeech/server/engine/text/python/text_engine.py
+++ b/paddlespeech/server/engine/text/python/text_engine.py
@@ -107,11 +107,14 @@ class PaddleTextConnectionHandler:
             assert len(tokens) == len(labels)
 
             text = ''
+            is_fast_model = 'fast' in self.text_engine.config.model_type
             for t, l in zip(tokens, labels):
                 text += t
                 if l != 0:  # Non punc.
-                    text += self._punc_list[l]
-
+                    if is_fast_model:
+                        text += self._punc_list[l - 1]
+                    else:
+                        text += self._punc_list[l]
             return text
         else:
             raise NotImplementedError
@@ -160,14 +163,23 @@ class TextEngine(BaseEngine):
             return False
 
         self.executor = TextServerExecutor()
-        self.executor._init_from_path(
-            task=config.task,
-            model_type=config.model_type,
-            lang=config.lang,
-            cfg_path=config.cfg_path,
-            ckpt_path=config.ckpt_path,
-            vocab_file=config.vocab_file)
-
+        if 'fast' in config.model_type:
+            self.executor._init_from_path_new(
+                task=config.task,
+                model_type=config.model_type,
+                lang=config.lang,
+                cfg_path=config.cfg_path,
+                ckpt_path=config.ckpt_path,
+                vocab_file=config.vocab_file)
+        else:
+            self.executor._init_from_path(
+                task=config.task,
+                model_type=config.model_type,
+                lang=config.lang,
+                cfg_path=config.cfg_path,
+                ckpt_path=config.ckpt_path,
+                vocab_file=config.vocab_file)
+        logger.info("Using model: %s." % (config.model_type))
         logger.info("Initialize Text server engine successfully on device: %s."
                     % (self.device))
         return True
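One behavioural note on the punctuation change in `text_engine.py`: for the 'fast' model type, label `l` now indexes `self._punc_list[l - 1]`, while the older models keep `self._punc_list[l]`; presumably the fast models' punctuation list does not carry a leading placeholder for the 'no punctuation' class. A hedged, standalone sketch of the restored mapping, with made-up token, label, and punctuation values:

```python
# Hedged sketch of the label-to-punctuation mapping handled above. Assumes
# label 0 always means "no punctuation"; only the list indexing differs
# between the fast and the older punctuation models.
def restore_punc(tokens, labels, punc_list, is_fast_model):
    text = ''
    for t, l in zip(tokens, labels):
        text += t
        if l != 0:  # non-zero label: a punctuation mark follows this token
            text += punc_list[l - 1] if is_fast_model else punc_list[l]
    return text

# Toy example: label 2 maps to "？" for a fast model (index 1 in its punc list).
print(restore_punc(["你", "好", "吗"], [0, 0, 2], ["，", "？"], is_fast_model=True))
```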