|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
from typing import ByteString
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
@ -30,9 +31,10 @@ from paddlespeech.s2t.transform.transformation import Transformation
|
|
|
|
|
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
|
|
|
|
|
from paddlespeech.s2t.utils.tensor_utils import pad_sequence
|
|
|
|
|
from paddlespeech.s2t.utils.utility import UpdateConfig
|
|
|
|
|
from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoingOpt
|
|
|
|
|
from paddlespeech.server.engine.asr.online.ctc_endpoint import OnlineCTCEndpoint
|
|
|
|
|
from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch
|
|
|
|
|
from paddlespeech.server.engine.base_engine import BaseEngine
|
|
|
|
|
from paddlespeech.server.utils.audio_process import pcm2float
|
|
|
|
|
from paddlespeech.server.utils.paddle_predictor import init_predictor
|
|
|
|
|
|
|
|
|
|
__all__ = ['PaddleASRConnectionHanddler', 'ASRServerExecutor', 'ASREngine']
|
|
|
|
@ -54,24 +56,33 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
self.model_config = asr_engine.executor.config
|
|
|
|
|
self.asr_engine = asr_engine
|
|
|
|
|
|
|
|
|
|
self.init()
|
|
|
|
|
self.reset()
|
|
|
|
|
|
|
|
|
|
def init(self):
|
|
|
|
|
# model_type, sample_rate and text_feature is shared for deepspeech2 and conformer
|
|
|
|
|
self.model_type = self.asr_engine.executor.model_type
|
|
|
|
|
self.sample_rate = self.asr_engine.executor.sample_rate
|
|
|
|
|
# tokens to text
|
|
|
|
|
self.text_feature = self.asr_engine.executor.text_feature
|
|
|
|
|
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
self.am_predictor = self.asr_engine.executor.am_predictor
|
|
|
|
|
|
|
|
|
|
# extract feat, new only fbank in conformer model
|
|
|
|
|
self.preprocess_conf = self.model_config.preprocess_config
|
|
|
|
|
self.preprocess_args = {"train": False}
|
|
|
|
|
self.preprocessing = Transformation(self.preprocess_conf)
|
|
|
|
|
|
|
|
|
|
# frame window and frame shift, in samples unit
|
|
|
|
|
self.win_length = self.preprocess_conf.process[0]['win_length']
|
|
|
|
|
self.n_shift = self.preprocess_conf.process[0]['n_shift']
|
|
|
|
|
|
|
|
|
|
assert self.preprocess_conf.process[0]['fs'] == self.sample_rate, (
|
|
|
|
|
self.sample_rate, self.preprocess_conf.process[0]['fs'])
|
|
|
|
|
self.frame_shift_in_ms = int(
|
|
|
|
|
self.n_shift / self.preprocess_conf.process[0]['fs'] * 1000)
|
|
|
|
|
|
|
|
|
|
self.init_decoder()
|
|
|
|
|
self.reset()
|
|
|
|
|
|
|
|
|
|
def init_decoder(self):
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
self.am_predictor = self.asr_engine.executor.am_predictor
|
|
|
|
|
|
|
|
|
|
self.decoder = CTCDecoder(
|
|
|
|
|
odim=self.model_config.output_dim, # <blank> is in vocab
|
|
|
|
|
enc_n_units=self.model_config.rnn_layer_size * 2,
|
|
|
|
@ -90,10 +101,6 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
|
|
|
|
|
cfg.num_proc_bsearch)
|
|
|
|
|
|
|
|
|
|
# frame window and frame shift, in samples unit
|
|
|
|
|
self.win_length = self.preprocess_conf.process[0]['win_length']
|
|
|
|
|
self.n_shift = self.preprocess_conf.process[0]['n_shift']
|
|
|
|
|
|
|
|
|
|
elif "conformer" in self.model_type or "transformer" in self.model_type:
|
|
|
|
|
# acoustic model
|
|
|
|
|
self.model = self.asr_engine.executor.model
|
|
|
|
@ -102,68 +109,88 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
self.ctc_decode_config = self.asr_engine.executor.config.decode
|
|
|
|
|
self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config)
|
|
|
|
|
|
|
|
|
|
# extract feat, new only fbank in conformer model
|
|
|
|
|
self.preprocess_conf = self.model_config.preprocess_config
|
|
|
|
|
self.preprocess_args = {"train": False}
|
|
|
|
|
self.preprocessing = Transformation(self.preprocess_conf)
|
|
|
|
|
|
|
|
|
|
# frame window and frame shift, in samples unit
|
|
|
|
|
self.win_length = self.preprocess_conf.process[0]['win_length']
|
|
|
|
|
self.n_shift = self.preprocess_conf.process[0]['n_shift']
|
|
|
|
|
# ctc endpoint
|
|
|
|
|
self.endpoint_opt = OnlineCTCEndpoingOpt(
|
|
|
|
|
frame_shift_in_ms=self.frame_shift_in_ms, blank=0)
|
|
|
|
|
self.endpointer = OnlineCTCEndpoint(self.endpoint_opt)
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Not supported: {self.model_type}")
|
|
|
|
|
|
|
|
|
|
def extract_feat(self, samples):
|
|
|
|
|
# we compute the elapsed time of first char occuring
|
|
|
|
|
# and we record the start time at the first pcm sample arraving
|
|
|
|
|
def model_reset(self):
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if "deepspeech2online" in self.model_type:
|
|
|
|
|
# self.reamined_wav stores all the samples,
|
|
|
|
|
# include the original remained_wav and this package samples
|
|
|
|
|
samples = np.frombuffer(samples, dtype=np.int16)
|
|
|
|
|
assert samples.ndim == 1
|
|
|
|
|
# feature cache
|
|
|
|
|
self.cached_feat = None
|
|
|
|
|
|
|
|
|
|
if self.remained_wav is None:
|
|
|
|
|
self.remained_wav = samples
|
|
|
|
|
else:
|
|
|
|
|
assert self.remained_wav.ndim == 1
|
|
|
|
|
self.remained_wav = np.concatenate([self.remained_wav, samples])
|
|
|
|
|
logger.info(
|
|
|
|
|
f"The connection remain the audio samples: {self.remained_wav.shape}"
|
|
|
|
|
)
|
|
|
|
|
## conformer
|
|
|
|
|
# cache for conformer online
|
|
|
|
|
self.subsampling_cache = None
|
|
|
|
|
self.elayers_output_cache = None
|
|
|
|
|
self.conformer_cnn_cache = None
|
|
|
|
|
self.encoder_out = None
|
|
|
|
|
# conformer decoding state
|
|
|
|
|
self.offset = 0 # global offset in decoding frame unit
|
|
|
|
|
|
|
|
|
|
# fbank
|
|
|
|
|
feat = self.preprocessing(self.remained_wav,
|
|
|
|
|
**self.preprocess_args)
|
|
|
|
|
feat = paddle.to_tensor(
|
|
|
|
|
feat, dtype="float32").unsqueeze(axis=0)
|
|
|
|
|
## just for record info
|
|
|
|
|
self.chunk_num = 0 # global decoding chunk num, not used
|
|
|
|
|
|
|
|
|
|
if self.cached_feat is None:
|
|
|
|
|
self.cached_feat = feat
|
|
|
|
|
else:
|
|
|
|
|
assert (len(feat.shape) == 3)
|
|
|
|
|
assert (len(self.cached_feat.shape) == 3)
|
|
|
|
|
self.cached_feat = paddle.concat(
|
|
|
|
|
[self.cached_feat, feat], axis=1)
|
|
|
|
|
def reset_continuous_decoding(self):
|
|
|
|
|
"""
|
|
|
|
|
when in continous decoding, reset for next utterance.
|
|
|
|
|
"""
|
|
|
|
|
self.global_frame_offset = self.num_frames
|
|
|
|
|
self.model_reset()
|
|
|
|
|
self.searcher.reset()
|
|
|
|
|
self.endpointer.reset()
|
|
|
|
|
|
|
|
|
|
# set the feat device
|
|
|
|
|
if self.device is None:
|
|
|
|
|
self.device = self.cached_feat.place
|
|
|
|
|
def reset(self):
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
# for deepspeech2
|
|
|
|
|
# init state
|
|
|
|
|
self.chunk_state_h_box = np.zeros(
|
|
|
|
|
(self.model_config.num_rnn_layers, 1,
|
|
|
|
|
self.model_config.rnn_layer_size),
|
|
|
|
|
dtype=float32)
|
|
|
|
|
self.chunk_state_c_box = np.zeros(
|
|
|
|
|
(self.model_config.num_rnn_layers, 1,
|
|
|
|
|
self.model_config.rnn_layer_size),
|
|
|
|
|
dtype=float32)
|
|
|
|
|
self.decoder.reset_decoder(batch_size=1)
|
|
|
|
|
|
|
|
|
|
# cur frame step
|
|
|
|
|
num_frames = feat.shape[1]
|
|
|
|
|
if "conformer" in self.model_type or "transformer" in self.model_type:
|
|
|
|
|
self.searcher.reset()
|
|
|
|
|
self.endpointer.reset()
|
|
|
|
|
|
|
|
|
|
self.num_frames += num_frames
|
|
|
|
|
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
|
|
|
|
|
self.device = None
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}"
|
|
|
|
|
)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}"
|
|
|
|
|
)
|
|
|
|
|
## common
|
|
|
|
|
# global sample and frame step
|
|
|
|
|
self.num_samples = 0
|
|
|
|
|
self.global_frame_offset = 0
|
|
|
|
|
# frame step of cur utterance
|
|
|
|
|
self.num_frames = 0
|
|
|
|
|
|
|
|
|
|
# cache for audio and feat
|
|
|
|
|
self.remained_wav = None
|
|
|
|
|
self.cached_feat = None
|
|
|
|
|
|
|
|
|
|
## conformer
|
|
|
|
|
self.model_reset()
|
|
|
|
|
|
|
|
|
|
## outputs
|
|
|
|
|
# partial/ending decoding results
|
|
|
|
|
self.result_transcripts = ['']
|
|
|
|
|
# token timestamp result
|
|
|
|
|
self.word_time_stamp = []
|
|
|
|
|
|
|
|
|
|
## just for record
|
|
|
|
|
self.hyps = []
|
|
|
|
|
|
|
|
|
|
# one best timestamp viterbi prob is large.
|
|
|
|
|
self.time_stamp = []
|
|
|
|
|
|
|
|
|
|
elif "conformer_online" in self.model_type:
|
|
|
|
|
def extract_feat(self, samples: ByteString):
|
|
|
|
|
logger.info("Online ASR extract the feat")
|
|
|
|
|
samples = np.frombuffer(samples, dtype=np.int16)
|
|
|
|
|
assert samples.ndim == 1
|
|
|
|
@ -189,10 +216,8 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
|
|
# fbank
|
|
|
|
|
x_chunk = self.preprocessing(self.remained_wav,
|
|
|
|
|
**self.preprocess_args)
|
|
|
|
|
x_chunk = paddle.to_tensor(
|
|
|
|
|
x_chunk, dtype="float32").unsqueeze(axis=0)
|
|
|
|
|
x_chunk = self.preprocessing(self.remained_wav, **self.preprocess_args)
|
|
|
|
|
x_chunk = paddle.to_tensor(x_chunk, dtype="float32").unsqueeze(axis=0)
|
|
|
|
|
|
|
|
|
|
# feature cache
|
|
|
|
|
if self.cached_feat is None:
|
|
|
|
@ -224,55 +249,6 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
)
|
|
|
|
|
logger.info(f"global samples: {self.num_samples}")
|
|
|
|
|
logger.info(f"global frames: {self.num_frames}")
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"not supported: {self.model_type}")
|
|
|
|
|
|
|
|
|
|
def reset(self):
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
# for deepspeech2
|
|
|
|
|
# init state
|
|
|
|
|
self.chunk_state_h_box = np.zeros(
|
|
|
|
|
(self.model_config.num_rnn_layers, 1,
|
|
|
|
|
self.model_config.rnn_layer_size),
|
|
|
|
|
dtype=float32)
|
|
|
|
|
self.chunk_state_c_box = np.zeros(
|
|
|
|
|
(self.model_config.num_rnn_layers, 1,
|
|
|
|
|
self.model_config.rnn_layer_size),
|
|
|
|
|
dtype=float32)
|
|
|
|
|
self.decoder.reset_decoder(batch_size=1)
|
|
|
|
|
|
|
|
|
|
self.device = None
|
|
|
|
|
|
|
|
|
|
## common
|
|
|
|
|
|
|
|
|
|
# global sample and frame step
|
|
|
|
|
self.num_samples = 0
|
|
|
|
|
self.num_frames = 0
|
|
|
|
|
|
|
|
|
|
# cache for audio and feat
|
|
|
|
|
self.remained_wav = None
|
|
|
|
|
self.cached_feat = None
|
|
|
|
|
|
|
|
|
|
# partial/ending decoding results
|
|
|
|
|
self.result_transcripts = ['']
|
|
|
|
|
|
|
|
|
|
## conformer
|
|
|
|
|
|
|
|
|
|
# cache for conformer online
|
|
|
|
|
self.subsampling_cache = None
|
|
|
|
|
self.elayers_output_cache = None
|
|
|
|
|
self.conformer_cnn_cache = None
|
|
|
|
|
self.encoder_out = None
|
|
|
|
|
# conformer decoding state
|
|
|
|
|
self.chunk_num = 0 # globa decoding chunk num
|
|
|
|
|
self.offset = 0 # global offset in decoding frame unit
|
|
|
|
|
self.hyps = []
|
|
|
|
|
|
|
|
|
|
# token timestamp result
|
|
|
|
|
self.word_time_stamp = []
|
|
|
|
|
|
|
|
|
|
# one best timestamp viterbi prob is large.
|
|
|
|
|
self.time_stamp = []
|
|
|
|
|
|
|
|
|
|
def decode(self, is_finished=False):
|
|
|
|
|
"""advance decoding
|
|
|
|
@ -280,14 +256,12 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
Args:
|
|
|
|
|
is_finished (bool, optional): Is last frame or not. Defaults to False.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
Exception: when not support model.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
None: nothing
|
|
|
|
|
None:
|
|
|
|
|
"""
|
|
|
|
|
if "deepspeech2online" in self.model_type:
|
|
|
|
|
if "deepspeech2" in self.model_type:
|
|
|
|
|
decoding_chunk_size = 1 # decoding chunk size = 1. int decoding frame unit
|
|
|
|
|
|
|
|
|
|
context = 7 # context=7, in audio frame unit
|
|
|
|
|
subsampling = 4 # subsampling=4, in audio frame unit
|
|
|
|
|
|
|
|
|
@ -332,9 +306,11 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
end = None
|
|
|
|
|
for cur in range(0, num_frames - left_frames + 1, stride):
|
|
|
|
|
end = min(cur + decoding_window, num_frames)
|
|
|
|
|
|
|
|
|
|
# extract the audio
|
|
|
|
|
x_chunk = self.cached_feat[:, cur:end, :].numpy()
|
|
|
|
|
x_chunk_lens = np.array([x_chunk.shape[1]])
|
|
|
|
|
|
|
|
|
|
trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens)
|
|
|
|
|
|
|
|
|
|
self.result_transcripts = [trans_best]
|
|
|
|
@ -409,31 +385,38 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
|
|
|
|
|
def advance_decoding(self, is_finished=False):
|
|
|
|
|
if "deepspeech" in self.model_type:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
"Conformer/Transformer: start to decode with advanced_decoding method"
|
|
|
|
|
)
|
|
|
|
|
cfg = self.ctc_decode_config
|
|
|
|
|
|
|
|
|
|
# cur chunk size, in decoding frame unit
|
|
|
|
|
# cur chunk size, in decoding frame unit, e.g. 16
|
|
|
|
|
decoding_chunk_size = cfg.decoding_chunk_size
|
|
|
|
|
# using num of history chunks
|
|
|
|
|
# using num of history chunks, e.g -1
|
|
|
|
|
num_decoding_left_chunks = cfg.num_decoding_left_chunks
|
|
|
|
|
assert decoding_chunk_size > 0
|
|
|
|
|
|
|
|
|
|
# e.g. 4
|
|
|
|
|
subsampling = self.model.encoder.embed.subsampling_rate
|
|
|
|
|
# e.g. 7
|
|
|
|
|
context = self.model.encoder.embed.right_context + 1
|
|
|
|
|
|
|
|
|
|
# processed chunk feature cached for next chunk
|
|
|
|
|
# processed chunk feature cached for next chunk, e.g. 3
|
|
|
|
|
cached_feature_num = context - subsampling
|
|
|
|
|
# decoding stride, in audio frame unit
|
|
|
|
|
stride = subsampling * decoding_chunk_size
|
|
|
|
|
|
|
|
|
|
# decoding window, in audio frame unit
|
|
|
|
|
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
|
|
|
|
# decoding stride, in audio frame unit
|
|
|
|
|
stride = subsampling * decoding_chunk_size
|
|
|
|
|
|
|
|
|
|
if self.cached_feat is None:
|
|
|
|
|
logger.info("no audio feat, please input more pcm data")
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# (B=1,T,D)
|
|
|
|
|
num_frames = self.cached_feat.shape[1]
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames"
|
|
|
|
@ -454,9 +437,6 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
return None, None
|
|
|
|
|
|
|
|
|
|
logger.info("start to do model forward")
|
|
|
|
|
# hist of chunks, in deocding frame unit
|
|
|
|
|
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
|
|
|
|
|
outputs = []
|
|
|
|
|
|
|
|
|
|
# num_frames - context + 1 ensure that current frame can get context window
|
|
|
|
|
if is_finished:
|
|
|
|
@ -466,7 +446,11 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
# we only process decoding_window frames for one chunk
|
|
|
|
|
left_frames = decoding_window
|
|
|
|
|
|
|
|
|
|
# hist of chunks, in deocding frame unit
|
|
|
|
|
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
|
|
|
|
|
|
|
|
|
|
# record the end for removing the processed feat
|
|
|
|
|
outputs = []
|
|
|
|
|
end = None
|
|
|
|
|
for cur in range(0, num_frames - left_frames + 1, stride):
|
|
|
|
|
end = min(cur + decoding_window, num_frames)
|
|
|
|
@ -491,30 +475,28 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
self.encoder_out = ys
|
|
|
|
|
else:
|
|
|
|
|
self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"This connection handler encoder out shape: {self.encoder_out.shape}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# get the ctc probs
|
|
|
|
|
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
|
|
|
|
|
ctc_probs = ctc_probs.squeeze(0)
|
|
|
|
|
|
|
|
|
|
## decoding
|
|
|
|
|
# advance decoding
|
|
|
|
|
self.searcher.search(ctc_probs, self.cached_feat.place)
|
|
|
|
|
# get one best hyps
|
|
|
|
|
self.hyps = self.searcher.get_one_best_hyps()
|
|
|
|
|
|
|
|
|
|
assert self.cached_feat.shape[0] == 1
|
|
|
|
|
assert end >= cached_feature_num
|
|
|
|
|
|
|
|
|
|
# advance cache of feat
|
|
|
|
|
self.cached_feat = self.cached_feat[0, end -
|
|
|
|
|
cached_feature_num:, :].unsqueeze(0)
|
|
|
|
|
assert self.cached_feat.shape[0] == 1 #(B=1,T,D)
|
|
|
|
|
assert end >= cached_feature_num
|
|
|
|
|
self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :]
|
|
|
|
|
assert len(
|
|
|
|
|
self.cached_feat.shape
|
|
|
|
|
) == 3, f"current cache feat shape is: {self.cached_feat.shape}"
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"This connection handler encoder out shape: {self.encoder_out.shape}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def update_result(self):
|
|
|
|
|
"""Conformer/Transformer hyps to result.
|
|
|
|
|
"""
|
|
|
|
@ -654,24 +636,28 @@ class PaddleASRConnectionHanddler:
|
|
|
|
|
|
|
|
|
|
# update each word start and end time stamp
|
|
|
|
|
# decoding frame to audio frame
|
|
|
|
|
frame_shift = self.model.encoder.embed.subsampling_rate
|
|
|
|
|
frame_shift_in_sec = frame_shift * (self.n_shift / self.sample_rate)
|
|
|
|
|
logger.info(f"frame shift sec: {frame_shift_in_sec}")
|
|
|
|
|
decode_frame_shift = self.model.encoder.embed.subsampling_rate
|
|
|
|
|
decode_frame_shift_in_sec = decode_frame_shift * (self.n_shift /
|
|
|
|
|
self.sample_rate)
|
|
|
|
|
logger.info(f"decode frame shift in sec: {decode_frame_shift_in_sec}")
|
|
|
|
|
|
|
|
|
|
global_offset_in_sec = self.global_frame_offset * self.frame_shift_in_ms / 1000.0
|
|
|
|
|
logger.info(f"global offset: {global_offset_in_sec} sec.")
|
|
|
|
|
|
|
|
|
|
word_time_stamp = []
|
|
|
|
|
for idx, _ in enumerate(self.time_stamp):
|
|
|
|
|
start = (self.time_stamp[idx - 1] + self.time_stamp[idx]
|
|
|
|
|
) / 2.0 if idx > 0 else 0
|
|
|
|
|
start = start * frame_shift_in_sec
|
|
|
|
|
start = start * decode_frame_shift_in_sec
|
|
|
|
|
|
|
|
|
|
end = (self.time_stamp[idx] + self.time_stamp[idx + 1]
|
|
|
|
|
) / 2.0 if idx < len(self.time_stamp) - 1 else self.offset
|
|
|
|
|
|
|
|
|
|
end = end * frame_shift_in_sec
|
|
|
|
|
end = end * decode_frame_shift_in_sec
|
|
|
|
|
word_time_stamp.append({
|
|
|
|
|
"w": self.result_transcripts[0][idx],
|
|
|
|
|
"bg": start,
|
|
|
|
|
"ed": end
|
|
|
|
|
"bg": global_offset_in_sec + start,
|
|
|
|
|
"ed": global_offset_in_sec + end
|
|
|
|
|
})
|
|
|
|
|
# logger.info(f"{word_time_stamp[-1]}")
|
|
|
|
|
|
|
|
|
@ -705,13 +691,14 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
|
|
|
|
|
self.model_type = model_type
|
|
|
|
|
self.sample_rate = sample_rate
|
|
|
|
|
logger.info(f"model_type: {self.model_type}")
|
|
|
|
|
|
|
|
|
|
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
|
|
|
|
|
tag = model_type + '-' + lang + '-' + sample_rate_str
|
|
|
|
|
self.task_resource.set_task_model(model_tag=tag)
|
|
|
|
|
|
|
|
|
|
if cfg_path is None or am_model is None or am_params is None:
|
|
|
|
|
logger.info(f"Load the pretrained model, tag = {tag}")
|
|
|
|
|
self.res_path = self.task_resource.res_dir
|
|
|
|
|
|
|
|
|
|
self.cfg_path = os.path.join(
|
|
|
|
|
self.res_path, self.task_resource.res_dict['cfg_path'])
|
|
|
|
|
|
|
|
|
@ -719,7 +706,6 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
self.task_resource.res_dict['model'])
|
|
|
|
|
self.am_params = os.path.join(self.res_path,
|
|
|
|
|
self.task_resource.res_dict['params'])
|
|
|
|
|
logger.info(self.res_path)
|
|
|
|
|
else:
|
|
|
|
|
self.cfg_path = os.path.abspath(cfg_path)
|
|
|
|
|
self.am_model = os.path.abspath(am_model)
|
|
|
|
@ -727,9 +713,12 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
self.res_path = os.path.dirname(
|
|
|
|
|
os.path.dirname(os.path.abspath(self.cfg_path)))
|
|
|
|
|
|
|
|
|
|
logger.info(self.cfg_path)
|
|
|
|
|
logger.info(self.am_model)
|
|
|
|
|
logger.info(self.am_params)
|
|
|
|
|
logger.info("Load the pretrained model:")
|
|
|
|
|
logger.info(f" tag = {tag}")
|
|
|
|
|
logger.info(f" res_path: {self.res_path}")
|
|
|
|
|
logger.info(f" cfg path: {self.cfg_path}")
|
|
|
|
|
logger.info(f" am_model path: {self.am_model}")
|
|
|
|
|
logger.info(f" am_params path: {self.am_params}")
|
|
|
|
|
|
|
|
|
|
#Init body.
|
|
|
|
|
self.config = CfgNode(new_allowed=True)
|
|
|
|
@ -738,13 +727,18 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
if self.config.spm_model_prefix:
|
|
|
|
|
self.config.spm_model_prefix = os.path.join(
|
|
|
|
|
self.res_path, self.config.spm_model_prefix)
|
|
|
|
|
logger.info(f"spm model path: {self.config.spm_model_prefix}")
|
|
|
|
|
|
|
|
|
|
self.vocab = self.config.vocab_filepath
|
|
|
|
|
|
|
|
|
|
self.text_feature = TextFeaturizer(
|
|
|
|
|
unit_type=self.config.unit_type,
|
|
|
|
|
vocab=self.config.vocab_filepath,
|
|
|
|
|
spm_model_prefix=self.config.spm_model_prefix)
|
|
|
|
|
self.vocab = self.config.vocab_filepath
|
|
|
|
|
with UpdateConfig(self.config):
|
|
|
|
|
|
|
|
|
|
if "deepspeech2" in model_type:
|
|
|
|
|
with UpdateConfig(self.config):
|
|
|
|
|
# download lm
|
|
|
|
|
self.config.decode.lang_model_path = os.path.join(
|
|
|
|
|
MODEL_HOME, 'language_model',
|
|
|
|
|
self.config.decode.lang_model_path)
|
|
|
|
@ -756,7 +750,16 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
lm_url,
|
|
|
|
|
os.path.dirname(self.config.decode.lang_model_path), lm_md5)
|
|
|
|
|
|
|
|
|
|
# AM predictor
|
|
|
|
|
logger.info("ASR engine start to init the am predictor")
|
|
|
|
|
self.am_predictor_conf = am_predictor_conf
|
|
|
|
|
self.am_predictor = init_predictor(
|
|
|
|
|
model_file=self.am_model,
|
|
|
|
|
params_file=self.am_params,
|
|
|
|
|
predictor_conf=self.am_predictor_conf)
|
|
|
|
|
|
|
|
|
|
elif "conformer" in model_type or "transformer" in model_type:
|
|
|
|
|
with UpdateConfig(self.config):
|
|
|
|
|
logger.info("start to create the stream conformer asr engine")
|
|
|
|
|
# update the decoding method
|
|
|
|
|
if decode_method:
|
|
|
|
@ -770,37 +773,24 @@ class ASRServerExecutor(ASRExecutor):
|
|
|
|
|
logger.info(
|
|
|
|
|
"we set the decoding_method to attention_rescoring")
|
|
|
|
|
self.config.decode.decoding_method = "attention_rescoring"
|
|
|
|
|
|
|
|
|
|
assert self.config.decode.decoding_method in [
|
|
|
|
|
"ctc_prefix_beam_search", "attention_rescoring"
|
|
|
|
|
], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}"
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("wrong type")
|
|
|
|
|
|
|
|
|
|
if "deepspeech2" in model_type:
|
|
|
|
|
# AM predictor
|
|
|
|
|
logger.info("ASR engine start to init the am predictor")
|
|
|
|
|
self.am_predictor_conf = am_predictor_conf
|
|
|
|
|
self.am_predictor = init_predictor(
|
|
|
|
|
model_file=self.am_model,
|
|
|
|
|
params_file=self.am_params,
|
|
|
|
|
predictor_conf=self.am_predictor_conf)
|
|
|
|
|
elif "conformer" in model_type or "transformer" in model_type:
|
|
|
|
|
# load model
|
|
|
|
|
model_name = model_type[:model_type.rindex(
|
|
|
|
|
'_')] # model_type: {model_name}_{dataset}
|
|
|
|
|
logger.info(f"model name: {model_name}")
|
|
|
|
|
model_class = self.task_resource.get_model_class(model_name)
|
|
|
|
|
model_conf = self.config
|
|
|
|
|
model = model_class.from_config(model_conf)
|
|
|
|
|
model = model_class.from_config(self.config)
|
|
|
|
|
self.model = model
|
|
|
|
|
self.model.set_state_dict(paddle.load(self.am_model))
|
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
|
|
# load model
|
|
|
|
|
model_dict = paddle.load(self.am_model)
|
|
|
|
|
self.model.set_state_dict(model_dict)
|
|
|
|
|
logger.info("create the transformer like model success")
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Not support: {model_type}")
|
|
|
|
|
raise Exception(f"not support: {model_type}")
|
|
|
|
|
|
|
|
|
|
logger.info(f"create the {model_type} model success")
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|