diff --git a/examples/csmsc/tts0/local/inference.sh b/examples/csmsc/tts0/local/inference.sh index e417d748..d2960441 100755 --- a/examples/csmsc/tts0/local/inference.sh +++ b/examples/csmsc/tts0/local/inference.sh @@ -27,20 +27,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --phones_dict=dump/phone_id_map.txt fi -# style melgan -# style melgan's Dygraph to Static Graph is not ready now -if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - python3 ${BIN_DIR}/../inference.py \ - --inference_dir=${train_output_path}/inference \ - --am=tacotron2_csmsc \ - --voc=style_melgan_csmsc \ - --text=${BIN_DIR}/../sentences.txt \ - --output_dir=${train_output_path}/pd_infer_out \ - --phones_dict=dump/phone_id_map.txt -fi - # hifigan -if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then python3 ${BIN_DIR}/../inference.py \ --inference_dir=${train_output_path}/inference \ --am=tacotron2_csmsc \ diff --git a/examples/csmsc/tts3/local/inference.sh b/examples/csmsc/tts3/local/inference.sh index 7052b347..b43fd286 100755 --- a/examples/csmsc/tts3/local/inference.sh +++ b/examples/csmsc/tts3/local/inference.sh @@ -28,7 +28,6 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --phones_dict=dump/phone_id_map.txt fi - # hifigan if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then python3 ${BIN_DIR}/../inference.py \ diff --git a/examples/csmsc/tts3/local/synthesize_e2e.sh b/examples/csmsc/tts3/local/synthesize_e2e.sh index 512e062b..8130eff1 100755 --- a/examples/csmsc/tts3/local/synthesize_e2e.sh +++ b/examples/csmsc/tts3/local/synthesize_e2e.sh @@ -109,6 +109,6 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then --lang=zh \ --text=${BIN_DIR}/../sentences.txt \ --output_dir=${train_output_path}/test_e2e \ - --phones_dict=dump/phone_id_map.txt \ - --inference_dir=${train_output_path}/inference + --phones_dict=dump/phone_id_map.txt #\ + # --inference_dir=${train_output_path}/inference fi diff --git a/examples/ljspeech/tts3/local/synthesize.sh b/examples/ljspeech/tts3/local/synthesize.sh index 6dc34274..0733e96f 100755 --- a/examples/ljspeech/tts3/local/synthesize.sh +++ b/examples/ljspeech/tts3/local/synthesize.sh @@ -26,7 +26,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then fi # hifigan -if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then FLAGS_allocator_strategy=naive_best_fit \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \ python3 ${BIN_DIR}/../synthesize.py \ diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 4b63e1e3..49dd7b35 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -40,7 +40,6 @@ from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] - @cli_register( name='paddlespeech.asr', description='Speech to text infer command.') class ASRExecutor(BaseExecutor): @@ -125,6 +124,7 @@ class ASRExecutor(BaseExecutor): """ Init model and other resources from a specific path. 
""" + logger.info("start to init the model") if hasattr(self, 'model'): logger.info('Model had been initialized.') return @@ -140,14 +140,15 @@ class ASRExecutor(BaseExecutor): res_path, self.pretrained_models[tag]['ckpt_path'] + ".pdparams") logger.info(res_path) - logger.info(self.cfg_path) - logger.info(self.ckpt_path) + else: self.cfg_path = os.path.abspath(cfg_path) self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams") self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) - + logger.info(self.cfg_path) + logger.info(self.ckpt_path) + #Init body. self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -176,7 +177,6 @@ class ASRExecutor(BaseExecutor): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) self.config.decode.decoding_method = decode_method - else: raise Exception("wrong type") model_name = model_type[:model_type.rindex( @@ -254,12 +254,14 @@ class ASRExecutor(BaseExecutor): else: raise Exception("wrong type") + logger.info("audio feat process success") + @paddle.no_grad() def infer(self, model_type: str): """ Model inference and result stored in self.output. """ - + logger.info("start to infer the model to get the output") cfg = self.config.decode audio = self._inputs["audio"] audio_len = self._inputs["audio_len"] @@ -276,17 +278,22 @@ class ASRExecutor(BaseExecutor): self._outputs["result"] = result_transcripts[0] elif "conformer" in model_type or "transformer" in model_type: - result_transcripts = self.model.decode( - audio, - audio_len, - text_feature=self.text_feature, - decoding_method=cfg.decoding_method, - beam_size=cfg.beam_size, - ctc_weight=cfg.ctc_weight, - decoding_chunk_size=cfg.decoding_chunk_size, - num_decoding_left_chunks=cfg.num_decoding_left_chunks, - simulate_streaming=cfg.simulate_streaming) - self._outputs["result"] = result_transcripts[0][0] + logger.info(f"we will use the transformer like model : {model_type}") + try: + result_transcripts = self.model.decode( + audio, + audio_len, + text_feature=self.text_feature, + decoding_method=cfg.decoding_method, + beam_size=cfg.beam_size, + ctc_weight=cfg.ctc_weight, + decoding_chunk_size=cfg.decoding_chunk_size, + num_decoding_left_chunks=cfg.num_decoding_left_chunks, + simulate_streaming=cfg.simulate_streaming) + self._outputs["result"] = result_transcripts[0][0] + except Exception as e: + logger.exception(e) + else: raise Exception("invalid model name") diff --git a/paddlespeech/cli/asr/pretrained_models.py b/paddlespeech/cli/asr/pretrained_models.py index a16c4750..cc52c751 100644 --- a/paddlespeech/cli/asr/pretrained_models.py +++ b/paddlespeech/cli/asr/pretrained_models.py @@ -88,6 +88,8 @@ model_alias = { "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline", "conformer": "paddlespeech.s2t.models.u2:U2Model", + "conformer_online": + "paddlespeech.s2t.models.u2:U2Model", "transformer": "paddlespeech.s2t.models.u2:U2Model", "wenetspeech": diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 6a98607b..9b66126e 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -279,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer): # TODO(Hui Zhang): if end_flag.sum() == running_size: if end_flag.cast(paddle.int64).sum() == running_size: break - + # 2.1 Forward decoder step hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( running_size, 1, 1).to(device) # (B*N, i, i) # logp: (B*N, vocab) logp, cache = self.decoder.forward_one_step( encoder_out, 
encoder_mask, hyps, hyps_mask, cache) - # 2.2 First beam prune: select topk best prob at current time top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N) top_k_logp = mask_finished_scores(top_k_logp, end_flag) @@ -708,11 +707,11 @@ class U2BaseModel(ASRInterface, nn.Layer): batch_size = feats.shape[0] if decoding_method in ['ctc_prefix_beam_search', 'attention_rescoring'] and batch_size > 1: - logger.fatal( + logger.error( f'decoding mode {decoding_method} must be running with batch_size == 1' ) + logger.error(f"current batch_size is {batch_size}") sys.exit(1) - if decoding_method == 'attention': hyps = self.recognize( feats,
diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 33ad472d..1bb15873 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase): # init once if self._ext_scorer is not None: return - + if language_model_path != '': logger.info("begin to initialize the external scorer " "for decoding")
diff --git a/paddlespeech/server/README.md b/paddlespeech/server/README.md index 819fe440..3ac68dae 100644 --- a/paddlespeech/server/README.md +++ b/paddlespeech/server/README.md @@ -35,3 +35,16 @@ ```bash paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav ``` + +## Online ASR Server + +### Launch the online ASR server +```bash +paddlespeech_server start --config_file conf/ws_conformer_application.yaml +``` + +### Access the online ASR server + +```bash +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav +``` \ No newline at end of file
diff --git a/paddlespeech/server/README_cn.md b/paddlespeech/server/README_cn.md index c0a4a733..5f235313 100644 --- a/paddlespeech/server/README_cn.md +++ b/paddlespeech/server/README_cn.md @@ -35,3 +35,17 @@ ```bash paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav ``` + +## 流式ASR + +### 启动流式语音识别服务 + +```bash +paddlespeech_server start --config_file conf/ws_conformer_application.yaml +``` + +### 访问流式语音识别服务 + +```bash +paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav +``` \ No newline at end of file
diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index cb802ce5..45469178 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor): lang=lang, audio_format=audio_format) time_end = time.time() - logger.info(res.json()) + logger.info(res) logger.info("Response time %f s."
% (time_end - time_start)) return True except Exception as e: logger.error("Failed to speech recognition.") + logger.error(e) return False @stats_wrapper @@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor): logging.info("asr websocket client start") handler = ASRAudioHandler(server_ip, port) loop = asyncio.get_event_loop() - loop.run_until_complete(handler.run(input)) + res = loop.run_until_complete(handler.run(input)) logging.info("asr websocket client finished") + return res['asr_results'] @cli_client_register( name='paddlespeech_client.cls', description='visit cls service') diff --git a/paddlespeech/server/conf/ws_application.yaml b/paddlespeech/server/conf/ws_application.yaml index b958bdf6..dee8d78b 100644 --- a/paddlespeech/server/conf/ws_application.yaml +++ b/paddlespeech/server/conf/ws_application.yaml @@ -41,11 +41,7 @@ asr_online: shift_ms: 40 sample_rate: 16000 sample_width: 2 - - vad_conf: - aggressiveness: 2 - sample_rate: 16000 - frame_duration_ms: 20 - sample_width: 2 - padding_ms: 200 - padding_ratio: 0.9 + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 20 # ms + shift_ms: 10 # ms diff --git a/paddlespeech/server/conf/ws_conformer_application.yaml b/paddlespeech/server/conf/ws_conformer_application.yaml new file mode 100644 index 00000000..e14833de --- /dev/null +++ b/paddlespeech/server/conf/ws_conformer_application.yaml @@ -0,0 +1,45 @@ +# This is the parameter configuration file for PaddleSpeech Serving. + +################################################################################# +# SERVER SETTING # +################################################################################# +host: 0.0.0.0 +port: 8090 + +# The task format in the engin_list is: _ +# task choices = ['asr_online', 'tts_online'] +# protocol = ['websocket', 'http'] (only one can be selected). +# websocket only support online engine type. +protocol: 'websocket' +engine_list: ['asr_online'] + + +################################################################################# +# ENGINE CONFIG # +################################################################################# + +################################### ASR ######################################### +################### speech task: asr; engine_type: online ####################### +asr_online: + model_type: 'conformer_online_multicn' + am_model: # the pdmodel file of am static model [optional] + am_params: # the pdiparams file of am static model [optional] + lang: 'zh' + sample_rate: 16000 + cfg_path: + decode_method: + force_yes: True + + am_predictor_conf: + device: # set 'gpu:id' or 'cpu' + switch_ir_optim: True + glog_info: False # True -> print glog + summary: True # False -> do not show predictor config + + chunk_buffer_conf: + window_n: 7 # frame + shift_n: 4 # frame + window_ms: 25 # ms + shift_ms: 10 # ms + sample_rate: 16000 + sample_width: 2 \ No newline at end of file diff --git a/paddlespeech/server/engine/asr/online/asr_engine.py b/paddlespeech/server/engine/asr/online/asr_engine.py index ca82b615..758cbaab 100644 --- a/paddlespeech/server/engine/asr/online/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/asr_engine.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
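# A rough usage sketch of the streaming client path shown above in
# paddlespeech_client.py: ASRAudioHandler streams a wav file to the websocket
# server launched from ws_conformer_application.yaml and returns a dict whose
# 'asr_results' field carries the transcription. The import path of
# ASRAudioHandler is not shown in this diff and is assumed here.
import asyncio
from paddlespeech.server.utils.audio_handler import ASRAudioHandler  # assumed location

handler = ASRAudioHandler("127.0.0.1", 8090)
res = asyncio.get_event_loop().run_until_complete(handler.run("input_16k.wav"))
print(res['asr_results'])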
+import copy import os from typing import Optional @@ -20,12 +21,19 @@ from numpy import float32 from yacs.config import CfgNode from paddlespeech.cli.asr.infer import ASRExecutor +from paddlespeech.cli.asr.infer import model_alias from paddlespeech.cli.log import logger +from paddlespeech.cli.utils import download_and_decompress from paddlespeech.cli.utils import MODEL_HOME from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.speech import SpeechSegment from paddlespeech.s2t.modules.ctc import CTCDecoder +from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.s2t.utils.dynamic_import import dynamic_import +from paddlespeech.s2t.utils.tensor_utils import add_sos_eos +from paddlespeech.s2t.utils.tensor_utils import pad_sequence from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import pcm2float from paddlespeech.server.utils.paddle_predictor import init_predictor @@ -35,9 +43,9 @@ __all__ = ['ASREngine'] pretrained_models = { "deepspeech2online_aishell-zh-16k": { 'url': - 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar.gz', + 'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz', 'md5': - 'd5e076217cf60486519f72c217d21b9b', + '23e16c69730a1cb5d735c98c83c21e16', 'cfg_path': 'model.yaml', 'ckpt_path': @@ -51,16 +59,543 @@ pretrained_models = { 'lm_md5': '29e02312deb2e59b3c8686c7966d4fe3' }, + "conformer_online_multicn-zh-16k": { + 'url': + 'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.3.model.tar.gz', + 'md5': + '0ac93d390552336f2a906aec9e33c5fa', + 'cfg_path': + 'model.yaml', + 'ckpt_path': + 'exp/chunk_conformer/checkpoints/multi_cn', + 'model': + 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', + 'params': + 'exp/chunk_conformer/checkpoints/multi_cn.pdparams', + 'lm_url': + 'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm', + 'lm_md5': + '29e02312deb2e59b3c8686c7966d4fe3' + }, } +# ASR server connection process class +class PaddleASRConnectionHanddler: + def __init__(self, asr_engine): + """Init a Paddle ASR Connection Handler instance + + Args: + asr_engine (ASREngine): the global asr engine + """ + super().__init__() + logger.info( + "create an paddle asr connection handler to process the websocket connection" + ) + self.config = asr_engine.config + self.model_config = asr_engine.executor.config + self.asr_engine = asr_engine + + self.init() + self.reset() + + def init(self): + # model_type, sample_rate and text_feature is shared for deepspeech2 and conformer + self.model_type = self.asr_engine.executor.model_type + self.sample_rate = self.asr_engine.executor.sample_rate + # tokens to text + self.text_feature = self.asr_engine.executor.text_feature + + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + from paddlespeech.s2t.io.collator import SpeechCollator + self.am_predictor = self.asr_engine.executor.am_predictor + + self.collate_fn_test = SpeechCollator.from_config(self.model_config) + self.decoder = CTCDecoder( + odim=self.model_config.output_dim, # is in vocab + enc_n_units=self.model_config.rnn_layer_size * 2, + blank_id=self.model_config.blank_id, + dropout_rate=0.0, + 
reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=self.model_config.get('ctc_grad_norm_type', + None)) + + cfg = self.model_config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + # frame window samples length and frame shift samples length + + self.win_length = int(self.model_config.window_ms / 1000 * + self.sample_rate) + self.n_shift = int(self.model_config.stride_ms / 1000 * + self.sample_rate) + + elif "conformer" in self.model_type or "transformer" in self.model_type: + # acoustic model + self.model = self.asr_engine.executor.model + + # ctc decoding config + self.ctc_decode_config = self.asr_engine.executor.config.decode + self.searcher = CTCPrefixBeamSearch(self.ctc_decode_config) + + # extract feat, new only fbank in conformer model + self.preprocess_conf = self.model_config.preprocess_config + self.preprocess_args = {"train": False} + self.preprocessing = Transformation(self.preprocess_conf) + + # frame window samples length and frame shift samples length + self.win_length = self.preprocess_conf.process[0]['win_length'] + self.n_shift = self.preprocess_conf.process[0]['n_shift'] + + def extract_feat(self, samples): + if "deepspeech2online" in self.model_type: + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + samples = np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + # pcm16 -> pcm 32 + # pcm2float will change the orignal samples, + # so we shoule do pcm2float before concatenate + samples = pcm2float(samples) + + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The connection remain the audio samples: {self.remained_wav.shape}" + ) + + # read audio + speech_segment = SpeechSegment.from_pcm( + self.remained_wav, self.sample_rate, transcript=" ") + # audio augment + self.collate_fn_test.augmentation.transform_audio(speech_segment) + + # extract speech feature + spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( + speech_segment, self.collate_fn_test.keep_transcription_text) + # CMVN spectrum + if self.collate_fn_test._normalizer: + spectrum = self.collate_fn_test._normalizer.apply(spectrum) + + # spectrum augment + audio = self.collate_fn_test.augmentation.transform_feature( + spectrum) + + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + # audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) + + if self.cached_feat is None: + self.cached_feat = audio + else: + assert (len(audio.shape) == 3) + assert (len(self.cached_feat.shape) == 3) + self.cached_feat = paddle.concat( + [self.cached_feat, audio], axis=1) + + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + + self.num_frames += audio_len + self.remained_wav = self.remained_wav[self.n_shift * audio_len:] + + logger.info( + f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" + ) + elif "conformer_online" in self.model_type: + logger.info("Online ASR extract the feat") + samples = 
np.frombuffer(samples, dtype=np.int16) + assert samples.ndim == 1 + + logger.info(f"This package receive {samples.shape[0]} pcm data") + self.num_samples += samples.shape[0] + + # self.reamined_wav stores all the samples, + # include the original remained_wav and this package samples + if self.remained_wav is None: + self.remained_wav = samples + else: + assert self.remained_wav.ndim == 1 + self.remained_wav = np.concatenate([self.remained_wav, samples]) + logger.info( + f"The connection remain the audio samples: {self.remained_wav.shape}" + ) + if len(self.remained_wav) < self.win_length: + return 0 + + # fbank + x_chunk = self.preprocessing(self.remained_wav, + **self.preprocess_args) + x_chunk = paddle.to_tensor( + x_chunk, dtype="float32").unsqueeze(axis=0) + if self.cached_feat is None: + self.cached_feat = x_chunk + else: + assert (len(x_chunk.shape) == 3) + assert (len(self.cached_feat.shape) == 3) + self.cached_feat = paddle.concat( + [self.cached_feat, x_chunk], axis=1) + + # set the feat device + if self.device is None: + self.device = self.cached_feat.place + + num_frames = x_chunk.shape[1] + self.num_frames += num_frames + self.remained_wav = self.remained_wav[self.n_shift * num_frames:] + + logger.info( + f"process the audio feature success, the connection feat shape: {self.cached_feat.shape}" + ) + logger.info( + f"After extract feat, the connection remain the audio samples: {self.remained_wav.shape}" + ) + # logger.info(f"accumulate samples: {self.num_samples}") + + def reset(self): + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + # for deepspeech2 + self.chunk_state_h_box = copy.deepcopy( + self.asr_engine.executor.chunk_state_h_box) + self.chunk_state_c_box = copy.deepcopy( + self.asr_engine.executor.chunk_state_c_box) + self.decoder.reset_decoder(batch_size=1) + + # for conformer online + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.encoder_out = None + self.cached_feat = None + self.remained_wav = None + self.offset = 0 + self.num_samples = 0 + self.device = None + self.hyps = [] + self.num_frames = 0 + self.chunk_num = 0 + self.global_frame_offset = 0 + self.result_transcripts = [''] + + def decode(self, is_finished=False): + if "deepspeech2online" in self.model_type: + # x_chunk 是特征数据 + decoding_chunk_size = 1 # decoding_chunk_size=1 in deepspeech2 model + context = 7 # context=7 in deepspeech2 model + subsampling = 4 # subsampling=4 in deepspeech2 model + stride = subsampling * decoding_chunk_size + cached_feature_num = context - subsampling + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + + if self.cached_feat is None: + logger.info("no audio feat, please input more pcm data") + return + + num_frames = self.cached_feat.shape[1] + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + # the cached feat must be larger decoding_window + if num_frames < decoding_window and not is_finished: + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) + return None, None + + # if is_finished=True, we need at least context frames + if num_frames < context: + logger.info( + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + ) + return None, None + logger.info("start to do model forward") + # num_frames - context + 1 ensure that current frame can get context window + if is_finished: + 
# if get the finished chunk, we need process the last context + left_frames = context + else: + # we only process decoding_window frames for one chunk + left_frames = decoding_window + + for cur in range(0, num_frames - left_frames + 1, stride): + end = min(cur + decoding_window, num_frames) + # extract the audio + x_chunk = self.cached_feat[:, cur:end, :].numpy() + x_chunk_lens = np.array([x_chunk.shape[1]]) + trans_best = self.decode_one_chunk(x_chunk, x_chunk_lens) + + self.result_transcripts = [trans_best] + + self.cached_feat = self.cached_feat[:, end - cached_feature_num:, :] + # return trans_best[0] + elif "conformer" in self.model_type or "transformer" in self.model_type: + try: + logger.info( + f"we will use the transformer like model : {self.model_type}" + ) + self.advance_decoding(is_finished) + self.update_result() + + except Exception as e: + logger.exception(e) + else: + raise Exception("invalid model name") + + @paddle.no_grad() + def decode_one_chunk(self, x_chunk, x_chunk_lens): + logger.info("start to decoce one chunk with deepspeech2 model") + input_names = self.am_predictor.get_input_names() + audio_handle = self.am_predictor.get_input_handle(input_names[0]) + audio_len_handle = self.am_predictor.get_input_handle(input_names[1]) + h_box_handle = self.am_predictor.get_input_handle(input_names[2]) + c_box_handle = self.am_predictor.get_input_handle(input_names[3]) + + audio_handle.reshape(x_chunk.shape) + audio_handle.copy_from_cpu(x_chunk) + + audio_len_handle.reshape(x_chunk_lens.shape) + audio_len_handle.copy_from_cpu(x_chunk_lens) + + h_box_handle.reshape(self.chunk_state_h_box.shape) + h_box_handle.copy_from_cpu(self.chunk_state_h_box) + + c_box_handle.reshape(self.chunk_state_c_box.shape) + c_box_handle.copy_from_cpu(self.chunk_state_c_box) + + output_names = self.am_predictor.get_output_names() + output_handle = self.am_predictor.get_output_handle(output_names[0]) + output_lens_handle = self.am_predictor.get_output_handle( + output_names[1]) + output_state_h_handle = self.am_predictor.get_output_handle( + output_names[2]) + output_state_c_handle = self.am_predictor.get_output_handle( + output_names[3]) + + self.am_predictor.run() + + output_chunk_probs = output_handle.copy_to_cpu() + output_chunk_lens = output_lens_handle.copy_to_cpu() + self.chunk_state_h_box = output_state_h_handle.copy_to_cpu() + self.chunk_state_c_box = output_state_c_handle.copy_to_cpu() + + self.decoder.next(output_chunk_probs, output_chunk_lens) + trans_best, trans_beam = self.decoder.decode() + logger.info(f"decode one best result: {trans_best[0]}") + return trans_best[0] + + @paddle.no_grad() + def advance_decoding(self, is_finished=False): + logger.info("start to decode with advanced_decoding method") + cfg = self.ctc_decode_config + decoding_chunk_size = cfg.decoding_chunk_size + num_decoding_left_chunks = cfg.num_decoding_left_chunks + + assert decoding_chunk_size > 0 + subsampling = self.model.encoder.embed.subsampling_rate + context = self.model.encoder.embed.right_context + 1 + stride = subsampling * decoding_chunk_size + cached_feature_num = context - subsampling # processed chunk feature cached for next chunk + + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + if self.cached_feat is None: + logger.info("no audio feat, please input more pcm data") + return + + num_frames = self.cached_feat.shape[1] + logger.info( + f"Required decoding window {decoding_window} frames, and the connection has {num_frames} frames" + ) + + # the cached 
feat must be larger decoding_window + if num_frames < decoding_window and not is_finished: + logger.info( + f"frame feat num is less than {decoding_window}, please input more pcm data" + ) + return None, None + + # if is_finished=True, we need at least context frames + if num_frames < context: + logger.info( + "flast {num_frames} is less than context {context} frames, and we cannot do model forward" + ) + return None, None + + logger.info("start to do model forward") + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + outputs = [] + + # num_frames - context + 1 ensure that current frame can get context window + if is_finished: + # if get the finished chunk, we need process the last context + left_frames = context + else: + # we only process decoding_window frames for one chunk + left_frames = decoding_window + + # record the end for removing the processed feat + end = None + for cur in range(0, num_frames - left_frames + 1, stride): + end = min(cur + decoding_window, num_frames) + + self.chunk_num += 1 + chunk_xs = self.cached_feat[:, cur:end, :] + (y, self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + chunk_xs, self.offset, required_cache_size, + self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) + outputs.append(y) + + # update the offset + self.offset += y.shape[1] + + ys = paddle.cat(outputs, 1) + if self.encoder_out is None: + self.encoder_out = ys + else: + self.encoder_out = paddle.concat([self.encoder_out, ys], axis=1) + + # get the ctc probs + ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + + self.searcher.search(ctc_probs, self.cached_feat.place) + + self.hyps = self.searcher.get_one_best_hyps() + assert self.cached_feat.shape[0] == 1 + assert end >= cached_feature_num + + self.cached_feat = self.cached_feat[0, end - + cached_feature_num:, :].unsqueeze(0) + assert len( + self.cached_feat.shape + ) == 3, f"current cache feat shape is: {self.cached_feat.shape}" + + logger.info( + f"This connection handler encoder out shape: {self.encoder_out.shape}" + ) + + def update_result(self): + logger.info("update the final result") + hyps = self.hyps + self.result_transcripts = [ + self.text_feature.defeaturize(hyp) for hyp in hyps + ] + self.result_tokenids = [hyp for hyp in hyps] + + def get_result(self): + if len(self.result_transcripts) > 0: + return self.result_transcripts[0] + else: + return '' + + @paddle.no_grad() + def rescoring(self): + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + return + + logger.info("rescoring the final result") + if "attention_rescoring" != self.ctc_decode_config.decoding_method: + return + + self.searcher.finalize_search() + self.update_result() + + beam_size = self.ctc_decode_config.beam_size + hyps = self.searcher.get_hyps() + if hyps is None or len(hyps) == 0: + return + + # assert len(hyps) == beam_size + hyp_list = [] + for hyp in hyps: + hyp_content = hyp[0] + # Prevent the hyp is empty + if len(hyp_content) == 0: + hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( + hyp_content, place=self.device, dtype=paddle.long) + hyp_list.append(hyp_content) + hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + hyps_lens = paddle.to_tensor( + [len(hyp[0]) for hyp in hyps], place=self.device, + dtype=paddle.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, + 
self.model.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + + encoder_out = self.encoder_out.repeat(beam_size, 1, 1) + encoder_mask = paddle.ones( + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + decoder_out, _ = self.model.decoder( + encoder_out, encoder_mask, hyps_pad, + hyps_lens) # (beam_size, max_hyps_len, vocab_size) + # ctc score in ln domain + decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) + decoder_out = decoder_out.numpy() + + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + # last decoder output token is `eos`, for laste decoder input token. + score += decoder_out[i][len(hyp[0])][self.model.eos] + # add ctc score (which in ln domain) + score += hyp[1] * self.ctc_decode_config.ctc_weight + if score > best_score: + best_score = score + best_index = i + + # update the one best result + logger.info(f"best index: {best_index}") + self.hyps = [hyps[best_index][0]] + self.update_result() + + class ASRServerExecutor(ASRExecutor): def __init__(self): super().__init__() pass + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. + """ + support_models = list(pretrained_models.keys()) + assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + + return decompressed_path + def _init_from_path(self, - model_type: str='wenetspeech', + model_type: str='deepspeech2online_aishell', am_model: Optional[os.PathLike]=None, am_params: Optional[os.PathLike]=None, lang: str='zh', @@ -71,12 +606,15 @@ class ASRServerExecutor(ASRExecutor): """ Init model and other resources from a specific path. """ - + self.model_type = model_type + self.sample_rate = sample_rate if cfg_path is None or am_model is None or am_params is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' tag = model_type + '-' + lang + '-' + sample_rate_str + logger.info(f"Load the pretrained model, tag = {tag}") res_path = self._get_pretrained_path(tag) # wenetspeech_zh self.res_path = res_path + self.cfg_path = os.path.join(res_path, pretrained_models[tag]['cfg_path']) @@ -85,9 +623,6 @@ class ASRServerExecutor(ASRExecutor): self.am_params = os.path.join(res_path, pretrained_models[tag]['params']) logger.info(res_path) - logger.info(self.cfg_path) - logger.info(self.am_model) - logger.info(self.am_params) else: self.cfg_path = os.path.abspath(cfg_path) self.am_model = os.path.abspath(am_model) @@ -95,6 +630,10 @@ class ASRServerExecutor(ASRExecutor): self.res_path = os.path.dirname( os.path.dirname(os.path.abspath(self.cfg_path))) + logger.info(self.cfg_path) + logger.info(self.am_model) + logger.info(self.am_params) + #Init body. 
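# The init body below parses the model yaml into a CfgNode; the branches later
# in this method then either build a static-graph predictor plus CTCDecoder
# state (deepspeech2online/offline) or load the dynamic conformer/transformer
# model together with a CTCPrefixBeamSearch instance for streaming decoding.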
self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -112,59 +651,107 @@ class ASRServerExecutor(ASRExecutor): lm_url = pretrained_models[tag]['lm_url'] lm_md5 = pretrained_models[tag]['lm_md5'] + logger.info(f"Start to load language model {lm_url}") self.download_lm( lm_url, os.path.dirname(self.config.decode.lang_model_path), lm_md5) - elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type: - raise Exception("wrong type") + elif "conformer" in model_type or "transformer" in model_type: + logger.info("start to create the stream conformer asr engine") + if self.config.spm_model_prefix: + self.config.spm_model_prefix = os.path.join( + self.res_path, self.config.spm_model_prefix) + self.vocab = self.config.vocab_filepath + self.text_feature = TextFeaturizer( + unit_type=self.config.unit_type, + vocab=self.config.vocab_filepath, + spm_model_prefix=self.config.spm_model_prefix) + # update the decoding method + if decode_method: + self.config.decode.decoding_method = decode_method + + # we only support ctc_prefix_beam_search and attention_rescoring dedoding method + # Generally we set the decoding_method to attention_rescoring + if self.config.decode.decoding_method not in [ + "ctc_prefix_beam_search", "attention_rescoring" + ]: + logger.info( + "we set the decoding_method to attention_rescoring") + self.config.decode.decoding = "attention_rescoring" + assert self.config.decode.decoding_method in [ + "ctc_prefix_beam_search", "attention_rescoring" + ], f"we only support ctc_prefix_beam_search and attention_rescoring dedoding method, current decoding method is {self.config.decode.decoding_method}" else: raise Exception("wrong type") + if "deepspeech2online" in model_type or "deepspeech2offline" in model_type: + # AM predictor + logger.info("ASR engine start to init the am predictor") + self.am_predictor_conf = am_predictor_conf + self.am_predictor = init_predictor( + model_file=self.am_model, + params_file=self.am_params, + predictor_conf=self.am_predictor_conf) - # AM predictor - self.am_predictor_conf = am_predictor_conf - self.am_predictor = init_predictor( - model_file=self.am_model, - params_file=self.am_params, - predictor_conf=self.am_predictor_conf) - - # decoder - self.decoder = CTCDecoder( - odim=self.config.output_dim, # is in vocab - enc_n_units=self.config.rnn_layer_size * 2, - blank_id=self.config.blank_id, - dropout_rate=0.0, - reduction=True, # sum - batch_average=True, # sum / batch_size - grad_norm_type=self.config.get('ctc_grad_norm_type', None)) - - # init decoder - cfg = self.config.decode - decode_batch_size = 1 # for online - self.decoder.init_decoder( - decode_batch_size, self.text_feature.vocab_list, - cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, - cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, - cfg.num_proc_bsearch) - - # init state box - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) + # decoder + logger.info("ASR engine start to create the ctc decoder instance") + self.decoder = CTCDecoder( + odim=self.config.output_dim, # is in vocab + enc_n_units=self.config.rnn_layer_size * 2, + blank_id=self.config.blank_id, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True, # sum / batch_size + grad_norm_type=self.config.get('ctc_grad_norm_type', None)) + + # init decoder + logger.info("ASR 
engine start to init the ctc decoder") + cfg = self.config.decode + decode_batch_size = 1 # for online + self.decoder.init_decoder( + decode_batch_size, self.text_feature.vocab_list, + cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta, + cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n, + cfg.num_proc_bsearch) + + # init state box + self.chunk_state_h_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + elif "conformer" in model_type or "transformer" in model_type: + model_name = model_type[:model_type.rindex( + '_')] # model_type: {model_name}_{dataset} + logger.info(f"model name: {model_name}") + model_class = dynamic_import(model_name, model_alias) + model_conf = self.config + model = model_class.from_config(model_conf) + self.model = model + self.model.eval() + + # load model + model_dict = paddle.load(self.am_model) + self.model.set_state_dict(model_dict) + logger.info("create the transformer like model success") + + # update the ctc decoding + self.searcher = CTCPrefixBeamSearch(self.config.decode) + self.transformer_decode_reset() def reset_decoder_and_chunk(self): """reset decoder and chunk state for an new audio """ - self.decoder.reset_decoder(batch_size=1) - # init state box, for new audio request - self.chunk_state_h_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) - self.chunk_state_c_box = np.zeros( - (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), - dtype=float32) + if "deepspeech2online" in self.model_type or "deepspeech2offline" in self.model_type: + self.decoder.reset_decoder(batch_size=1) + # init state box, for new audio request + self.chunk_state_h_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + self.chunk_state_c_box = np.zeros( + (self.config.num_rnn_layers, 1, self.config.rnn_layer_size), + dtype=float32) + elif "conformer" in self.model_type or "transformer" in self.model_type: + self.transformer_decode_reset() def decode_one_chunk(self, x_chunk, x_chunk_lens, model_type: str): """decode one chunk @@ -175,8 +762,9 @@ class ASRServerExecutor(ASRExecutor): model_type (str): online model type Returns: - [type]: [description] + str: one best result """ + logger.info("start to decoce chunk by chunk") if "deepspeech2online" in model_type: input_names = self.am_predictor.get_input_names() audio_handle = self.am_predictor.get_input_handle(input_names[0]) @@ -215,14 +803,142 @@ class ASRServerExecutor(ASRExecutor): self.decoder.next(output_chunk_probs, output_chunk_lens) trans_best, trans_beam = self.decoder.decode() - + logger.info(f"decode one best result: {trans_best[0]}") return trans_best[0] elif "conformer" in model_type or "transformer" in model_type: - raise Exception("invalid model name") + try: + logger.info( + f"we will use the transformer like model : {self.model_type}" + ) + self.advanced_decoding(x_chunk, x_chunk_lens) + self.update_result() + + return self.result_transcripts[0] + except Exception as e: + logger.exception(e) else: raise Exception("invalid model name") + def advanced_decoding(self, xs: paddle.Tensor, x_chunk_lens): + logger.info("start to decode with advanced_decoding method") + encoder_out, encoder_mask = self.encoder_forward(xs) + ctc_probs = self.model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + ctc_probs = ctc_probs.squeeze(0) + 
self.searcher.search(ctc_probs, xs.place) + # update the one best result + self.hyps = self.searcher.get_one_best_hyps() + + # now we supprot ctc_prefix_beam_search and attention_rescoring + if "attention_rescoring" in self.config.decode.decoding_method: + self.rescoring(encoder_out, xs.place) + + def encoder_forward(self, xs): + logger.info("get the model out from the feat") + cfg = self.config.decode + decoding_chunk_size = cfg.decoding_chunk_size + num_decoding_left_chunks = cfg.num_decoding_left_chunks + + assert decoding_chunk_size > 0 + subsampling = self.model.encoder.embed.subsampling_rate + context = self.model.encoder.embed.right_context + 1 + stride = subsampling * decoding_chunk_size + + # decoding window for model + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.shape[1] + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + logger.info("start to do model forward") + outputs = [] + + # num_frames - context + 1 ensure that current frame can get context window + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) = self.model.encoder.forward_chunk( + chunk_xs, self.offset, required_cache_size, + self.subsampling_cache, self.elayers_output_cache, + self.conformer_cnn_cache) + outputs.append(y) + self.offset += y.shape[1] + + ys = paddle.cat(outputs, 1) + masks = paddle.ones([1, ys.shape[1]], dtype=paddle.bool) + masks = masks.unsqueeze(1) + return ys, masks + + def rescoring(self, encoder_out, device): + logger.info("start to rescoring the hyps") + beam_size = self.config.decode.beam_size + hyps = self.searcher.get_hyps() + assert len(hyps) == beam_size + + hyp_list = [] + for hyp in hyps: + hyp_content = hyp[0] + # Prevent the hyp is empty + if len(hyp_content) == 0: + hyp_content = (self.model.ctc.blank_id, ) + hyp_content = paddle.to_tensor( + hyp_content, place=device, dtype=paddle.long) + hyp_list.append(hyp_content) + hyps_pad = pad_sequence(hyp_list, True, self.model.ignore_id) + hyps_lens = paddle.to_tensor( + [len(hyp[0]) for hyp in hyps], place=device, + dtype=paddle.long) # (beam_size,) + hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos, + self.model.ignore_id) + hyps_lens = hyps_lens + 1 # Add at begining + + encoder_out = encoder_out.repeat(beam_size, 1, 1) + encoder_mask = paddle.ones( + (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) + decoder_out, _ = self.model.decoder( + encoder_out, encoder_mask, hyps_pad, + hyps_lens) # (beam_size, max_hyps_len, vocab_size) + # ctc score in ln domain + decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) + decoder_out = decoder_out.numpy() + + # Only use decoder score for rescoring + best_score = -float('inf') + best_index = 0 + # hyps is List[(Text=List[int], Score=float)], len(hyps)=beam_size + for i, hyp in enumerate(hyps): + score = 0.0 + for j, w in enumerate(hyp[0]): + score += decoder_out[i][j][w] + # last decoder output token is `eos`, for laste decoder input token. 
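# i.e. the total score of hypothesis i is
#   sum_j decoder_out[i][j][w_j] + decoder_out[i][len(hyp)][eos] + ctc_weight * ctc_score_i
# where ctc_score_i is the prefix beam search score carried in hyp[1].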
+ score += decoder_out[i][len(hyp[0])][self.model.eos] + # add ctc score (which in ln domain) + score += hyp[1] * self.config.decode.ctc_weight + if score > best_score: + best_score = score + best_index = i + + # update the one best result + self.hyps = [hyps[best_index][0]] + return hyps[best_index][0] + + def transformer_decode_reset(self): + self.subsampling_cache = None + self.elayers_output_cache = None + self.conformer_cnn_cache = None + self.offset = 0 + # decoding reset + self.searcher.reset() + + def update_result(self): + logger.info("update the final result") + hyps = self.hyps + self.result_transcripts = [ + self.text_feature.defeaturize(hyp) for hyp in hyps + ] + self.result_tokenids = [hyp for hyp in hyps] + def extract_feat(self, samples, sample_rate): """extract feat @@ -234,34 +950,58 @@ class ASRServerExecutor(ASRExecutor): x_chunk (numpy.array): shape[B, T, D] x_chunk_lens (numpy.array): shape[B] """ - # pcm16 -> pcm 32 - samples = pcm2float(samples) - # read audio - speech_segment = SpeechSegment.from_pcm( - samples, sample_rate, transcript=" ") - # audio augment - self.collate_fn_test.augmentation.transform_audio(speech_segment) + if "deepspeech2online" in self.model_type: + # pcm16 -> pcm 32 + samples = pcm2float(samples) + # read audio + speech_segment = SpeechSegment.from_pcm( + samples, sample_rate, transcript=" ") + # audio augment + self.collate_fn_test.augmentation.transform_audio(speech_segment) - # extract speech feature - spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( - speech_segment, self.collate_fn_test.keep_transcription_text) - # CMVN spectrum - if self.collate_fn_test._normalizer: - spectrum = self.collate_fn_test._normalizer.apply(spectrum) + # extract speech feature + spectrum, transcript_part = self.collate_fn_test._speech_featurizer.featurize( + speech_segment, self.collate_fn_test.keep_transcription_text) + # CMVN spectrum + if self.collate_fn_test._normalizer: + spectrum = self.collate_fn_test._normalizer.apply(spectrum) - # spectrum augment - audio = self.collate_fn_test.augmentation.transform_feature(spectrum) + # spectrum augment + audio = self.collate_fn_test.augmentation.transform_feature( + spectrum) - audio_len = audio.shape[0] - audio = paddle.to_tensor(audio, dtype='float32') - # audio_len = paddle.to_tensor(audio_len) - audio = paddle.unsqueeze(audio, axis=0) + audio_len = audio.shape[0] + audio = paddle.to_tensor(audio, dtype='float32') + # audio_len = paddle.to_tensor(audio_len) + audio = paddle.unsqueeze(audio, axis=0) - x_chunk = audio.numpy() - x_chunk_lens = np.array([audio_len]) + x_chunk = audio.numpy() + x_chunk_lens = np.array([audio_len]) - return x_chunk, x_chunk_lens + return x_chunk, x_chunk_lens + elif "conformer_online" in self.model_type: + + if sample_rate != self.sample_rate: + logger.info(f"audio sample rate {sample_rate} is not match," + "the model sample_rate is {self.sample_rate}") + logger.info(f"ASR Engine use the {self.model_type} to process") + logger.info("Create the preprocess instance") + preprocess_conf = self.config.preprocess_config + preprocess_args = {"train": False} + preprocessing = Transformation(preprocess_conf) + + logger.info("Read the audio file") + logger.info(f"audio shape: {samples.shape}") + # fbank + x_chunk = preprocessing(samples, **preprocess_args) + x_chunk_lens = paddle.to_tensor(x_chunk.shape[0]) + x_chunk = paddle.to_tensor( + x_chunk, dtype="float32").unsqueeze(axis=0) + logger.info( + f"process the audio feature success, feat shape: {x_chunk.shape}" + 
) + return x_chunk, x_chunk_lens class ASREngine(BaseEngine): @@ -273,6 +1013,7 @@ class ASREngine(BaseEngine): def __init__(self): super(ASREngine, self).__init__() + logger.info("create the online asr engine instance") def init(self, config: dict) -> bool: """init engine resource @@ -301,7 +1042,10 @@ class ASREngine(BaseEngine): logger.info("Initialize ASR server engine successfully.") return True - def preprocess(self, samples, sample_rate): + def preprocess(self, + samples, + sample_rate, + model_type="deepspeech2online_aishell-zh-16k"): """preprocess Args: @@ -312,6 +1056,7 @@ class ASREngine(BaseEngine): x_chunk (numpy.array): shape[B, T, D] x_chunk_lens (numpy.array): shape[B] """ + # if "deepspeech" in model_type: x_chunk, x_chunk_lens = self.executor.extract_feat(samples, sample_rate) return x_chunk, x_chunk_lens diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py new file mode 100644 index 00000000..8aee0a50 --- /dev/null +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import defaultdict +import paddle +from paddlespeech.cli.log import logger +from paddlespeech.s2t.utils.utility import log_add + +__all__ = ['CTCPrefixBeamSearch'] + + +class CTCPrefixBeamSearch: + def __init__(self, config): + """Implement the ctc prefix beam search + + Args: + config (yacs.config.CfgNode): _description_ + """ + self.config = config + self.reset() + + @paddle.no_grad() + def search(self, ctc_probs, device, blank_id=0): + """ctc prefix beam search method decode a chunk feature + + Args: + xs (paddle.Tensor): feature data + ctc_probs (paddle.Tensor): the ctc probability of all the tokens + device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0). + blank_id (int, optional): the blank id in the vocab. Defaults to 0. + + Returns: + list: the search result + """ + # decode + logger.info("start to ctc prefix search") + + batch_size = 1 + beam_size = self.config.beam_size + maxlen = ctc_probs.shape[0] + + assert len(ctc_probs.shape) == 2 + + # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score)) + # blank_ending_score and none_blank_ending_score in ln domain + if self.cur_hyps is None: + self.cur_hyps = [(tuple(), (0.0, -float('inf')))] + # 2. 
CTC beam search step by step + for t in range(0, maxlen): + logp = ctc_probs[t] # (vocab_size,) + # key: prefix, value (pb, pnb), default value(-inf, -inf) + next_hyps = defaultdict(lambda: (-float('inf'), -float('inf'))) + + # 2.1 First beam prune: select topk best + # do token passing process + top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,) + for s in top_k_index: + s = s.item() + ps = logp[s].item() + for prefix, (pb, pnb) in self.cur_hyps: + last = prefix[-1] if len(prefix) > 0 else None + if s == blank_id: # blank + n_pb, n_pnb = next_hyps[prefix] + n_pb = log_add([n_pb, pb + ps, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + elif s == last: + # Update *ss -> *s; + n_pb, n_pnb = next_hyps[prefix] + n_pnb = log_add([n_pnb, pnb + ps]) + next_hyps[prefix] = (n_pb, n_pnb) + # Update *s-s -> *ss, - is for blank + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + else: + n_prefix = prefix + (s, ) + n_pb, n_pnb = next_hyps[n_prefix] + n_pnb = log_add([n_pnb, pb + ps, pnb + ps]) + next_hyps[n_prefix] = (n_pb, n_pnb) + + # 2.2 Second beam prune + next_hyps = sorted( + next_hyps.items(), + key=lambda x: log_add(list(x[1])), + reverse=True) + self.cur_hyps = next_hyps[:beam_size] + + self.hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in self.cur_hyps] + logger.info("ctc prefix search success") + return self.hyps + + def get_one_best_hyps(self): + """Return the one best result + + Returns: + list: the one best result + """ + return [self.hyps[0][0]] + + def get_hyps(self): + """Return the search hyps + + Returns: + list: return the search hyps + """ + return self.hyps + + def reset(self): + """Rest the search cache value + """ + self.cur_hyps = None + self.hyps = None + + def finalize_search(self): + """do nothing in ctc_prefix_beam_search + """ + pass diff --git a/paddlespeech/server/engine/tts/online/tts_engine.py b/paddlespeech/server/engine/tts/online/tts_engine.py index 25a8bc76..c9135b88 100644 --- a/paddlespeech/server/engine/tts/online/tts_engine.py +++ b/paddlespeech/server/engine/tts/online/tts_engine.py @@ -12,24 +12,329 @@ # See the License for the specific language governing permissions and # limitations under the License. 
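# A minimal sketch of how the CTCPrefixBeamSearch class added above in
# ctc_search.py is driven chunk by chunk. The random log-probs stand in for the
# CTC head output of one chunk, and the CfgNode only carries the beam_size
# field the searcher reads; everything else matches the class as defined above.
import paddle
from yacs.config import CfgNode
from paddlespeech.server.engine.asr.online.ctc_search import CTCPrefixBeamSearch

searcher = CTCPrefixBeamSearch(CfgNode({'beam_size': 10}))
# stand-in for one chunk of CTC output: [num_frames, vocab_size] log-probs
ctc_probs = paddle.nn.functional.log_softmax(paddle.randn([16, 5000]), axis=-1)
searcher.search(ctc_probs, ctc_probs.place)   # call once per chunk; state accumulates
print(searcher.get_one_best_hyps()[0])        # best token-id prefix so far
searcher.finalize_search()                    # no-op for pure prefix beam search
searcher.reset()                              # clear state before the next utterance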
import base64 +import math +import os import time +from typing import Optional import numpy as np import paddle +import yaml +from yacs.config import CfgNode from paddlespeech.cli.log import logger from paddlespeech.cli.tts.infer import TTSExecutor +from paddlespeech.cli.utils import download_and_decompress +from paddlespeech.cli.utils import MODEL_HOME +from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils.audio_process import float2pcm +from paddlespeech.server.utils.util import denorm from paddlespeech.server.utils.util import get_chunks +from paddlespeech.t2s.frontend import English +from paddlespeech.t2s.frontend.zh_frontend import Frontend +from paddlespeech.t2s.modules.normalizer import ZScore + +__all__ = ['TTSEngine'] + +# support online model +pretrained_models = { + # fastspeech2 + "fastspeech2_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip', + 'md5': + '637d28a5e53aa60275612ba4393d5f22', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_76000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + "fastspeech2_cnndecoder_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_cnndecoder_csmsc_ckpt_1.0.0.zip', + 'md5': + '6eb28e22ace73e0ebe7845f86478f89f', + 'config': + 'cnndecoder.yaml', + 'ckpt': + 'snapshot_iter_153000.pdz', + 'speech_stats': + 'speech_stats.npy', + 'phones_dict': + 'phone_id_map.txt', + }, + + # mb_melgan + "mb_melgan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/mb_melgan/mb_melgan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'ee5f0604e20091f0d495b6ec4618b90d', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_1000000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, + + # hifigan + "hifigan_csmsc-zh": { + 'url': + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip', + 'md5': + 'dd40a3d88dfcf64513fba2f0f961ada6', + 'config': + 'default.yaml', + 'ckpt': + 'snapshot_iter_2500000.pdz', + 'speech_stats': + 'feats_stats.npy', + }, +} + +model_alias = { + # acoustic model + "fastspeech2": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2", + "fastspeech2_inference": + "paddlespeech.t2s.models.fastspeech2:FastSpeech2Inference", + + # voc + "mb_melgan": + "paddlespeech.t2s.models.melgan:MelGANGenerator", + "mb_melgan_inference": + "paddlespeech.t2s.models.melgan:MelGANInference", + "hifigan": + "paddlespeech.t2s.models.hifigan:HiFiGANGenerator", + "hifigan_inference": + "paddlespeech.t2s.models.hifigan:HiFiGANInference", +} __all__ = ['TTSEngine'] class TTSServerExecutor(TTSExecutor): - def __init__(self): + def __init__(self, am_block, am_pad, voc_block, voc_pad): super().__init__() - pass + self.am_block = am_block + self.am_pad = am_pad + self.voc_block = voc_block + self.voc_pad = voc_pad + + def get_model_info(self, + field: str, + model_name: str, + ckpt: Optional[os.PathLike], + stat: Optional[os.PathLike]): + """get model information + + Args: + field (str): am or voc + model_name (str): model type, support fastspeech2, higigan, mb_melgan + ckpt (Optional[os.PathLike]): ckpt file + stat (Optional[os.PathLike]): stat file, including mean and standard deviation + + Returns: + [module]: model module + [Tensor]: mean + [Tensor]: standard deviation + """ + + model_class = dynamic_import(model_name, 
model_alias) + + if field == "am": + odim = self.am_config.n_mels + model = model_class( + idim=self.vocab_size, odim=odim, **self.am_config["model"]) + model.set_state_dict(paddle.load(ckpt)["main_params"]) + + elif field == "voc": + model = model_class(**self.voc_config["generator_params"]) + model.set_state_dict(paddle.load(ckpt)["generator_params"]) + model.remove_weight_norm() + + else: + logger.error("Please set correct field, am or voc") + + model.eval() + model_mu, model_std = np.load(stat) + model_mu = paddle.to_tensor(model_mu) + model_std = paddle.to_tensor(model_std) + + return model, model_mu, model_std + + def _get_pretrained_path(self, tag: str) -> os.PathLike: + """ + Download and returns pretrained resources path of current task. + """ + support_models = list(pretrained_models.keys()) + assert tag in pretrained_models, 'The model "{}" you want to use has not been supported, please choose other models.\nThe support models includes:\n\t\t{}\n'.format( + tag, '\n\t\t'.join(support_models)) + + res_path = os.path.join(MODEL_HOME, tag) + decompressed_path = download_and_decompress(pretrained_models[tag], + res_path) + decompressed_path = os.path.abspath(decompressed_path) + logger.info( + 'Use pretrained model stored in: {}'.format(decompressed_path)) + return decompressed_path + + def _init_from_path( + self, + am: str='fastspeech2_csmsc', + am_config: Optional[os.PathLike]=None, + am_ckpt: Optional[os.PathLike]=None, + am_stat: Optional[os.PathLike]=None, + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None, + speaker_dict: Optional[os.PathLike]=None, + voc: str='mb_melgan_csmsc', + voc_config: Optional[os.PathLike]=None, + voc_ckpt: Optional[os.PathLike]=None, + voc_stat: Optional[os.PathLike]=None, + lang: str='zh', ): + """ + Init model and other resources from a specific path. 
+ """ + if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'): + logger.info('Models had been initialized.') + return + # am model info + am_tag = am + '-' + lang + if am_ckpt is None or am_config is None or am_stat is None or phones_dict is None: + am_res_path = self._get_pretrained_path(am_tag) + self.am_res_path = am_res_path + self.am_config = os.path.join(am_res_path, + pretrained_models[am_tag]['config']) + self.am_ckpt = os.path.join(am_res_path, + pretrained_models[am_tag]['ckpt']) + self.am_stat = os.path.join( + am_res_path, pretrained_models[am_tag]['speech_stats']) + # must have phones_dict in acoustic + self.phones_dict = os.path.join( + am_res_path, pretrained_models[am_tag]['phones_dict']) + print("self.phones_dict:", self.phones_dict) + logger.info(am_res_path) + logger.info(self.am_config) + logger.info(self.am_ckpt) + else: + self.am_config = os.path.abspath(am_config) + self.am_ckpt = os.path.abspath(am_ckpt) + self.am_stat = os.path.abspath(am_stat) + self.phones_dict = os.path.abspath(phones_dict) + self.am_res_path = os.path.dirname(os.path.abspath(self.am_config)) + print("self.phones_dict:", self.phones_dict) + + self.tones_dict = None + self.speaker_dict = None + + # voc model info + voc_tag = voc + '-' + lang + if voc_ckpt is None or voc_config is None or voc_stat is None: + voc_res_path = self._get_pretrained_path(voc_tag) + self.voc_res_path = voc_res_path + self.voc_config = os.path.join(voc_res_path, + pretrained_models[voc_tag]['config']) + self.voc_ckpt = os.path.join(voc_res_path, + pretrained_models[voc_tag]['ckpt']) + self.voc_stat = os.path.join( + voc_res_path, pretrained_models[voc_tag]['speech_stats']) + logger.info(voc_res_path) + logger.info(self.voc_config) + logger.info(self.voc_ckpt) + else: + self.voc_config = os.path.abspath(voc_config) + self.voc_ckpt = os.path.abspath(voc_ckpt) + self.voc_stat = os.path.abspath(voc_stat) + self.voc_res_path = os.path.dirname( + os.path.abspath(self.voc_config)) + + # Init body. 
+ with open(self.am_config) as f: + self.am_config = CfgNode(yaml.safe_load(f)) + with open(self.voc_config) as f: + self.voc_config = CfgNode(yaml.safe_load(f)) + + with open(self.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + self.vocab_size = len(phn_id) + print("vocab_size:", self.vocab_size) + + # frontend + if lang == 'zh': + self.frontend = Frontend( + phone_vocab_path=self.phones_dict, + tone_vocab_path=self.tones_dict) + + elif lang == 'en': + self.frontend = English(phone_vocab_path=self.phones_dict) + print("frontend done!") + + # am infer info + self.am_name = am[:am.rindex('_')] + if self.am_name == "fastspeech2_cnndecoder": + self.am_inference, self.am_mu, self.am_std = self.get_model_info( + "am", "fastspeech2", self.am_ckpt, self.am_stat) + else: + am, am_mu, am_std = self.get_model_info("am", self.am_name, + self.am_ckpt, self.am_stat) + am_normalizer = ZScore(am_mu, am_std) + am_inference_class = dynamic_import(self.am_name + '_inference', + model_alias) + self.am_inference = am_inference_class(am_normalizer, am) + self.am_inference.eval() + print("acoustic model done!") + + # voc infer info + self.voc_name = voc[:voc.rindex('_')] + voc, voc_mu, voc_std = self.get_model_info("voc", self.voc_name, + self.voc_ckpt, self.voc_stat) + voc_normalizer = ZScore(voc_mu, voc_std) + voc_inference_class = dynamic_import(self.voc_name + '_inference', + model_alias) + self.voc_inference = voc_inference_class(voc_normalizer, voc) + self.voc_inference.eval() + print("voc done!") + + def get_phone(self, sentence, lang, merge_sentences, get_tone_ids): + tone_ids = None + if lang == 'zh': + input_ids = self.frontend.get_input_ids( + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) + phone_ids = input_ids["phone_ids"] + if get_tone_ids: + tone_ids = input_ids["tone_ids"] + elif lang == 'en': + input_ids = self.frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + + def depadding(self, data, chunk_num, chunk_id, block, pad, upsample): + """ + Streaming inference removes the result of pad inference + """ + front_pad = min(chunk_id * block, pad) + # first chunk + if chunk_id == 0: + data = data[:block * upsample] + # last chunk + elif chunk_id == chunk_num - 1: + data = data[front_pad * upsample:] + # middle chunk + else: + data = data[front_pad * upsample:(front_pad + block) * upsample] + + return data @paddle.no_grad() def infer( @@ -37,16 +342,20 @@ class TTSServerExecutor(TTSExecutor): text: str, lang: str='zh', am: str='fastspeech2_csmsc', - spk_id: int=0, - am_block: int=42, - am_pad: int=12, - voc_block: int=14, - voc_pad: int=14, ): + spk_id: int=0, ): """ Model inference and result stored in self.output. 
""" - am_name = am[:am.rindex('_')] - am_dataset = am[am.rindex('_') + 1:] + + am_block = self.am_block + am_pad = self.am_pad + am_upsample = 1 + voc_block = self.voc_block + voc_pad = self.voc_pad + voc_upsample = self.voc_config.n_shift + # first_flag 用于标记首包 + first_flag = 1 + get_tone_ids = False merge_sentences = False frontend_st = time.time() @@ -64,43 +373,100 @@ class TTSServerExecutor(TTSExecutor): phone_ids = input_ids["phone_ids"] else: print("lang should in {'zh', 'en'}!") - self.frontend_time = time.time() - frontend_st + frontend_et = time.time() + self.frontend_time = frontend_et - frontend_st for i in range(len(phone_ids)): - am_st = time.time() part_phone_ids = phone_ids[i] - # am - if am_name == 'speedyspeech': - part_tone_ids = tone_ids[i] - mel = self.am_inference(part_phone_ids, part_tone_ids) - # fastspeech2 + voc_chunk_id = 0 + + # fastspeech2_csmsc + if am == "fastspeech2_csmsc": + # am + mel = self.am_inference(part_phone_ids) + if first_flag == 1: + first_am_et = time.time() + self.first_am_infer = first_am_et - frontend_et + + # voc streaming + mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") + voc_chunk_num = len(mel_chunks) + voc_st = time.time() + for i, mel_chunk in enumerate(mel_chunks): + sub_wav = self.voc_inference(mel_chunk) + sub_wav = self.depadding(sub_wav, voc_chunk_num, i, + voc_block, voc_pad, voc_upsample) + if first_flag == 1: + first_voc_et = time.time() + self.first_voc_infer = first_voc_et - first_am_et + self.first_response_time = first_voc_et - frontend_st + first_flag = 0 + + yield sub_wav + + # fastspeech2_cnndecoder_csmsc + elif am == "fastspeech2_cnndecoder_csmsc": + # am + orig_hs, h_masks = self.am_inference.encoder_infer( + part_phone_ids) + + # streaming voc chunk info + mel_len = orig_hs.shape[1] + voc_chunk_num = math.ceil(mel_len / self.voc_block) + start = 0 + end = min(self.voc_block + self.voc_pad, mel_len) + + # streaming am + hss = get_chunks(orig_hs, self.am_block, self.am_pad, "am") + am_chunk_num = len(hss) + for i, hs in enumerate(hss): + before_outs, _ = self.am_inference.decoder(hs) + after_outs = before_outs + self.am_inference.postnet( + before_outs.transpose((0, 2, 1))).transpose((0, 2, 1)) + normalized_mel = after_outs[0] + sub_mel = denorm(normalized_mel, self.am_mu, self.am_std) + sub_mel = self.depadding(sub_mel, am_chunk_num, i, am_block, + am_pad, am_upsample) + + if i == 0: + mel_streaming = sub_mel + else: + mel_streaming = np.concatenate( + (mel_streaming, sub_mel), axis=0) + + # streaming voc + # 当流式AM推理的mel帧数大于流式voc推理的chunk size,开始进行流式voc 推理 + while (mel_streaming.shape[0] >= end and + voc_chunk_id < voc_chunk_num): + if first_flag == 1: + first_am_et = time.time() + self.first_am_infer = first_am_et - frontend_et + voc_chunk = mel_streaming[start:end, :] + voc_chunk = paddle.to_tensor(voc_chunk) + sub_wav = self.voc_inference(voc_chunk) + + sub_wav = self.depadding(sub_wav, voc_chunk_num, + voc_chunk_id, voc_block, + voc_pad, voc_upsample) + if first_flag == 1: + first_voc_et = time.time() + self.first_voc_infer = first_voc_et - first_am_et + self.first_response_time = first_voc_et - frontend_st + first_flag = 0 + + yield sub_wav + + voc_chunk_id += 1 + start = max(0, voc_chunk_id * voc_block - voc_pad) + end = min((voc_chunk_id + 1) * voc_block + voc_pad, + mel_len) + else: - # multi speaker - if am_dataset in {"aishell3", "vctk"}: - mel = self.am_inference( - part_phone_ids, spk_id=paddle.to_tensor(spk_id)) - else: - mel = self.am_inference(part_phone_ids) - am_et = time.time() - - # voc 
streaming - voc_upsample = self.voc_config.n_shift - mel_chunks = get_chunks(mel, voc_block, voc_pad, "voc") - chunk_num = len(mel_chunks) - voc_st = time.time() - for i, mel_chunk in enumerate(mel_chunks): - sub_wav = self.voc_inference(mel_chunk) - front_pad = min(i * voc_block, voc_pad) - - if i == 0: - sub_wav = sub_wav[:voc_block * voc_upsample] - elif i == chunk_num - 1: - sub_wav = sub_wav[front_pad * voc_upsample:] - else: - sub_wav = sub_wav[front_pad * voc_upsample:( - front_pad + voc_block) * voc_upsample] - - yield sub_wav + logger.error( + "Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts." + ) + + self.final_response_time = time.time() - frontend_st class TTSEngine(BaseEngine): @@ -113,14 +479,21 @@ class TTSEngine(BaseEngine): def __init__(self, name=None): """Initialize TTS server engine """ - super(TTSEngine, self).__init__() + super().__init__() def init(self, config: dict) -> bool: - self.executor = TTSServerExecutor() self.config = config - assert "fastspeech2_csmsc" in config.am and ( - config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc" + assert ( + config.am == "fastspeech2_csmsc" or + config.am == "fastspeech2_cnndecoder_csmsc" + ) and ( + config.voc == "hifigan_csmsc" or config.voc == "mb_melgan_csmsc" ), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.' + + assert ( + config.voc_block > 0 and config.voc_pad > 0 + ), "Please set correct voc_block and voc_pad, they should be more than 0." + try: if self.config.device: self.device = self.config.device @@ -135,6 +508,9 @@ class TTSEngine(BaseEngine): (self.device)) return False + self.executor = TTSServerExecutor(config.am_block, config.am_pad, + config.voc_block, config.voc_pad) + try: self.executor._init_from_path( am=self.config.am, @@ -155,15 +531,42 @@ class TTSEngine(BaseEngine): (self.device)) return False - self.am_block = self.config.am_block - self.am_pad = self.config.am_pad - self.voc_block = self.config.voc_block - self.voc_pad = self.config.voc_pad - logger.info("Initialize TTS server engine successfully on device: %s." % (self.device)) + + # warm up + try: + self.warm_up() + except Exception as e: + logger.error("Failed to warm up on tts engine.") + return False + return True + def warm_up(self): + """warm up + """ + if self.config.lang == 'zh': + sentence = "您好,欢迎使用语音合成服务。" + if self.config.lang == 'en': + sentence = "Hello and welcome to the speech synthesis service." + logger.info( + "*******************************warm up ********************************" + ) + for i in range(3): + for wav in self.executor.infer( + text=sentence, + lang=self.config.lang, + am=self.config.am, + spk_id=0, ): + logger.info( + f"The first response time of the {i} warm up: {self.executor.first_response_time} s" + ) + break + logger.info( + "**********************************************************************" + ) + def preprocess(self, text_bese64: str=None, text_bytes: bytes=None): # Convert byte to text if text_bese64: @@ -195,18 +598,14 @@ class TTSEngine(BaseEngine): wav_base64: The base64 format of the synthesized audio. 
""" - lang = self.config.lang wav_list = [] for wav in self.executor.infer( text=sentence, - lang=lang, + lang=self.config.lang, am=self.config.am, - spk_id=spk_id, - am_block=self.am_block, - am_pad=self.am_pad, - voc_block=self.voc_block, - voc_pad=self.voc_pad): + spk_id=spk_id, ): + # wav type: float32, convert to pcm (base64) wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes @@ -216,5 +615,14 @@ class TTSEngine(BaseEngine): yield wav_base64 wav_all = np.concatenate(wav_list, axis=0) - logger.info("The durations of audio is: {} s".format( - len(wav_all) / self.executor.am_config.fs)) + duration = len(wav_all) / self.executor.am_config.fs + logger.info(f"sentence: {sentence}") + logger.info(f"The durations of audio is: {duration} s") + logger.info( + f"first response time: {self.executor.first_response_time} s") + logger.info( + f"final response time: {self.executor.final_response_time} s") + logger.info(f"RTF: {self.executor.final_response_time / duration}") + logger.info( + f"Other info: front time: {self.executor.frontend_time} s, first am infer time: {self.executor.first_am_infer} s, first voc infer time: {self.executor.first_voc_infer} s," + ) diff --git a/paddlespeech/server/tests/__init__.py b/paddlespeech/server/tests/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/tests/asr/__init__.py b/paddlespeech/server/tests/asr/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/tests/asr/offline/__init__.py b/paddlespeech/server/tests/asr/offline/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/offline/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/tests/asr/online/__init__.py b/paddlespeech/server/tests/asr/online/__init__.py new file mode 100644 index 00000000..97043fd7 --- /dev/null +++ b/paddlespeech/server/tests/asr/online/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py index 01b19405..49cbd703 100644 --- a/paddlespeech/server/tests/asr/online/websocket_client.py +++ b/paddlespeech/server/tests/asr/online/websocket_client.py @@ -34,10 +34,9 @@ class ASRAudioHandler: def read_wave(self, wavfile_path: str): samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') x_len = len(samples) - # chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz - chunk_size = 80 * 16 #80ms, sample_rate = 16kHz - if x_len % chunk_size != 0: + chunk_size = 85 * 16 #80ms, sample_rate = 16kHz + if x_len % chunk_size!= 0: padding_len_x = chunk_size - x_len % chunk_size else: padding_len_x = 0 @@ -48,7 +47,6 @@ class ASRAudioHandler: assert (x_len + padding_len_x) % chunk_size == 0 num_chunk = (x_len + padding_len_x) / chunk_size num_chunk = int(num_chunk) - for i in range(0, num_chunk): start = i * chunk_size end = start + chunk_size @@ -57,7 +55,11 @@ class ASRAudioHandler: async def run(self, wavfile_path: str): logging.info("send a message to the server") + # self.read_wave() + # send websocket handshake protocal async with websockets.connect(self.url) as ws: + # server has already received handshake protocal + # client start to send the command audio_info = json.dumps( { "name": "test.wav", @@ -78,7 +80,6 @@ class ASRAudioHandler: msg = json.loads(msg) logging.info("receive msg={}".format(msg)) - result = msg # finished audio_info = json.dumps( { @@ -91,10 +92,12 @@ class ASRAudioHandler: separators=(',', ': ')) await ws.send(audio_info) msg = await ws.recv() + + # decode the bytes to str msg = json.loads(msg) - logging.info("receive msg={}".format(msg)) - - return result + logging.info("final receive msg={}".format(msg)) + result = msg + return result def main(args): diff --git a/paddlespeech/server/utils/buffer.py b/paddlespeech/server/utils/buffer.py index 12b1f0e5..d4e6cd49 100644 --- a/paddlespeech/server/utils/buffer.py +++ b/paddlespeech/server/utils/buffer.py @@ -63,12 +63,12 @@ class ChunkBuffer(object): the sample rate. Yields Frames of the requested duration. 
""" + audio = self.remained_audio + audio self.remained_audio = b'' offset = 0 timestamp = 0.0 - while offset + self.window_bytes <= len(audio): yield Frame(audio[offset:offset + self.window_bytes], timestamp, self.window_sec) diff --git a/paddlespeech/server/utils/util.py b/paddlespeech/server/utils/util.py index 0fe70849..72ee0060 100644 --- a/paddlespeech/server/utils/util.py +++ b/paddlespeech/server/utils/util.py @@ -52,6 +52,10 @@ def get_chunks(data, block_size, pad_size, step): Returns: list: chunks list """ + + if block_size == -1: + return [data] + if step == "am": data_len = data.shape[1] elif step == "voc": diff --git a/paddlespeech/server/ws/asr_socket.py b/paddlespeech/server/ws/asr_socket.py index 03a49b48..a865703d 100644 --- a/paddlespeech/server/ws/asr_socket.py +++ b/paddlespeech/server/ws/asr_socket.py @@ -13,12 +13,12 @@ # limitations under the License. import json -import numpy as np from fastapi import APIRouter from fastapi import WebSocket from fastapi import WebSocketDisconnect from starlette.websockets import WebSocketState as WebSocketState +from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler from paddlespeech.server.engine.engine_pool import get_engine_pool from paddlespeech.server.utils.buffer import ChunkBuffer from paddlespeech.server.utils.vad import VADAudio @@ -28,26 +28,29 @@ router = APIRouter() @router.websocket('/ws/asr') async def websocket_endpoint(websocket: WebSocket): - await websocket.accept() engine_pool = get_engine_pool() asr_engine = engine_pool['asr'] + connection_handler = None # init buffer + # each websocekt connection has its own chunk buffer chunk_buffer_conf = asr_engine.config.chunk_buffer_conf chunk_buffer = ChunkBuffer( - window_n=7, - shift_n=4, - window_ms=20, - shift_ms=10, - sample_rate=chunk_buffer_conf['sample_rate'], - sample_width=chunk_buffer_conf['sample_width']) + window_n=chunk_buffer_conf.window_n, + shift_n=chunk_buffer_conf.shift_n, + window_ms=chunk_buffer_conf.window_ms, + shift_ms=chunk_buffer_conf.shift_ms, + sample_rate=chunk_buffer_conf.sample_rate, + sample_width=chunk_buffer_conf.sample_width) + # init vad - vad_conf = asr_engine.config.vad_conf - vad = VADAudio( - aggressiveness=vad_conf['aggressiveness'], - rate=vad_conf['sample_rate'], - frame_duration_ms=vad_conf['frame_duration_ms']) + vad_conf = asr_engine.config.get('vad_conf', None) + if vad_conf: + vad = VADAudio( + aggressiveness=vad_conf['aggressiveness'], + rate=vad_conf['sample_rate'], + frame_duration_ms=vad_conf['frame_duration_ms']) try: while True: @@ -64,13 +67,21 @@ async def websocket_endpoint(websocket: WebSocket): if message['signal'] == 'start': resp = {"status": "ok", "signal": "server_ready"} # do something at begining here + # create the instance to process the audio + connection_handler = PaddleASRConnectionHanddler(asr_engine) await websocket.send_json(resp) elif message['signal'] == 'end': - engine_pool = get_engine_pool() - asr_engine = engine_pool['asr'] # reset single engine for an new connection - asr_engine.reset() - resp = {"status": "ok", "signal": "finished"} + connection_handler.decode(is_finished=True) + connection_handler.rescoring() + asr_results = connection_handler.get_result() + connection_handler.reset() + + resp = { + "status": "ok", + "signal": "finished", + 'asr_results': asr_results + } await websocket.send_json(resp) break else: @@ -79,21 +90,11 @@ async def websocket_endpoint(websocket: WebSocket): elif "bytes" in message: message = message["bytes"] - engine_pool = 
get_engine_pool() - asr_engine = engine_pool['asr'] - asr_results = "" - frames = chunk_buffer.frame_generator(message) - for frame in frames: - samples = np.frombuffer(frame.bytes, dtype=np.int16) - sample_rate = asr_engine.config.sample_rate - x_chunk, x_chunk_lens = asr_engine.preprocess(samples, - sample_rate) - asr_engine.run(x_chunk, x_chunk_lens) - asr_results = asr_engine.postprocess() + connection_handler.extract_feat(message) + connection_handler.decode(is_finished=False) + asr_results = connection_handler.get_result() - asr_results = asr_engine.postprocess() resp = {'asr_results': asr_results} - await websocket.send_json(resp) except WebSocketDisconnect: pass diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 3e7c11f2..98e73e10 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -14,6 +14,7 @@ import argparse from pathlib import Path +import paddle import soundfile as sf from timer import timer @@ -101,21 +102,35 @@ def parse_args(): # only inference for models trained with csmsc now def main(): args = parse_args() + + paddle.set_device(args.device) + # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) # am_predictor - am_predictor = get_predictor(args, filed='am') + am_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.am + ".pdmodel", + params_file=args.am + ".pdiparams", + device=args.device) # model: {model_name}_{dataset} am_dataset = args.am[args.am.rindex('_') + 1:] # voc_predictor - voc_predictor = get_predictor(args, filed='voc') + voc_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.voc + ".pdmodel", + params_file=args.voc + ".pdiparams", + device=args.device) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) merge_sentences = True fs = 24000 if am_dataset != 'ljspeech' else 22050 @@ -123,11 +138,13 @@ def main(): for utt_id, sentence in sentences[:3]: with timer() as t: am_output_data = get_am_output( - args, + input=sentence, am_predictor=am_predictor, + am=args.am, frontend=frontend, + lang=args.lang, merge_sentences=merge_sentences, - input=sentence) + speaker_dict=args.speaker_dict, ) wav = get_voc_output( voc_predictor=voc_predictor, input=am_output_data) speed = wav.size / t.elapse @@ -143,11 +160,13 @@ def main(): for utt_id, sentence in sentences: with timer() as t: am_output_data = get_am_output( - args, + input=sentence, am_predictor=am_predictor, + am=args.am, frontend=frontend, + lang=args.lang, merge_sentences=merge_sentences, - input=sentence) + speaker_dict=args.speaker_dict, ) wav = get_voc_output( voc_predictor=voc_predictor, input=am_output_data) diff --git a/paddlespeech/t2s/exps/inference_streaming.py b/paddlespeech/t2s/exps/inference_streaming.py index 0e58056c..b680f19a 100644 --- a/paddlespeech/t2s/exps/inference_streaming.py +++ b/paddlespeech/t2s/exps/inference_streaming.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -25,7 +26,6 @@ from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_predictor from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_streaming_am_output -from 
paddlespeech.t2s.exps.syn_utils import get_streaming_am_predictor from paddlespeech.t2s.exps.syn_utils import get_voc_output from paddlespeech.t2s.utils import str2bool @@ -101,23 +101,47 @@ def parse_args(): # only inference for models trained with csmsc now def main(): args = parse_args() + + paddle.set_device(args.device) + # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) # am_predictor - am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor = get_streaming_am_predictor( - args) + + am_encoder_infer_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.am + "_am_encoder_infer" + ".pdmodel", + params_file=args.am + "_am_encoder_infer" + ".pdiparams", + device=args.device) + am_decoder_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.am + "_am_decoder" + ".pdmodel", + params_file=args.am + "_am_decoder" + ".pdiparams", + device=args.device) + am_postnet_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.am + "_am_postnet" + ".pdmodel", + params_file=args.am + "_am_postnet" + ".pdiparams", + device=args.device) am_mu, am_std = np.load(args.am_stat) # model: {model_name}_{dataset} am_dataset = args.am[args.am.rindex('_') + 1:] # voc_predictor - voc_predictor = get_predictor(args, filed='voc') + voc_predictor = get_predictor( + model_dir=args.inference_dir, + model_file=args.voc + ".pdmodel", + params_file=args.voc + ".pdiparams", + device=args.device) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) merge_sentences = True @@ -126,13 +150,13 @@ def main(): for utt_id, sentence in sentences[:3]: with timer() as t: normalized_mel = get_streaming_am_output( - args, + input=sentence, am_encoder_infer_predictor=am_encoder_infer_predictor, am_decoder_predictor=am_decoder_predictor, am_postnet_predictor=am_postnet_predictor, frontend=frontend, - merge_sentences=merge_sentences, - input=sentence) + lang=args.lang, + merge_sentences=merge_sentences, ) mel = denorm(normalized_mel, am_mu, am_std) wav = get_voc_output(voc_predictor=voc_predictor, input=mel) speed = wav.size / t.elapse diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py index d1f03710..2e8596de 100644 --- a/paddlespeech/t2s/exps/ort_predict.py +++ b/paddlespeech/t2s/exps/ort_predict.py @@ -16,6 +16,7 @@ from pathlib import Path import jsonlines import numpy as np +import paddle import soundfile as sf from timer import timer @@ -25,12 +26,13 @@ from paddlespeech.t2s.utils import str2bool def ort_predict(args): + # construct dataset for evaluation with jsonlines.open(args.test_metadata, 'r') as reader: test_metadata = list(reader) am_name = args.am[:args.am.rindex('_')] am_dataset = args.am[args.am.rindex('_') + 1:] - test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset) + test_dataset = get_test_dataset(test_metadata=test_metadata, am=args.am) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -38,10 +40,18 @@ def ort_predict(args): fs = 24000 if am_dataset != 'ljspeech' else 22050 # am - am_sess = get_sess(args, filed='am') + am_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.am + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) # vocoder - voc_sess = get_sess(args, filed='voc') + voc_sess = 
get_sess( + model_dir=args.inference_dir, + model_file=args.voc + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) # am warmup for T in [27, 38, 54]: @@ -135,6 +145,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args) diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py index 366a2902..a2ef8e4c 100644 --- a/paddlespeech/t2s/exps/ort_predict_e2e.py +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -27,21 +28,31 @@ from paddlespeech.t2s.utils import str2bool def ort_predict(args): # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) am_name = args.am[:args.am.rindex('_')] am_dataset = args.am[args.am.rindex('_') + 1:] fs = 24000 if am_dataset != 'ljspeech' else 22050 - # am - am_sess = get_sess(args, filed='am') + am_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.am + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) # vocoder - voc_sess = get_sess(args, filed='voc') + voc_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.voc + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) # frontend warmup # Loading model cost 0.5+ seconds @@ -168,6 +179,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args) diff --git a/paddlespeech/t2s/exps/ort_predict_streaming.py b/paddlespeech/t2s/exps/ort_predict_streaming.py index 1b486d19..5d2c66bc 100644 --- a/paddlespeech/t2s/exps/ort_predict_streaming.py +++ b/paddlespeech/t2s/exps/ort_predict_streaming.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -23,30 +24,50 @@ from paddlespeech.t2s.exps.syn_utils import get_chunks from paddlespeech.t2s.exps.syn_utils import get_frontend from paddlespeech.t2s.exps.syn_utils import get_sentences from paddlespeech.t2s.exps.syn_utils import get_sess -from paddlespeech.t2s.exps.syn_utils import get_streaming_am_sess from paddlespeech.t2s.utils import str2bool def ort_predict(args): # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) am_name = args.am[:args.am.rindex('_')] am_dataset = args.am[args.am.rindex('_') + 1:] fs = 24000 if am_dataset != 'ljspeech' else 22050 - # am - am_encoder_infer_sess, am_decoder_sess, am_postnet_sess = get_streaming_am_sess( - args) + # streaming acoustic model + am_encoder_infer_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.am + "_am_encoder_infer" + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) + am_decoder_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.am + "_am_decoder" + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) + + am_postnet_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.am + 
"_am_postnet" + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) am_mu, am_std = np.load(args.am_stat) # vocoder - voc_sess = get_sess(args, filed='voc') + voc_sess = get_sess( + model_dir=args.inference_dir, + model_file=args.voc + ".onnx", + device=args.device, + cpu_threads=args.cpu_threads) # frontend warmup # Loading model cost 0.5+ seconds @@ -226,6 +247,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args) diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index 21aa5bf8..ce0aee05 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -14,6 +14,10 @@ import math import os from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Optional import numpy as np import onnxruntime as ort @@ -21,6 +25,7 @@ import paddle from paddle import inference from paddle import jit from paddle.static import InputSpec +from yacs.config import CfgNode from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.t2s.datasets.data_table import DataTable @@ -70,7 +75,7 @@ def denorm(data, mean, std): return data * std + mean -def get_chunks(data, chunk_size, pad_size): +def get_chunks(data, chunk_size: int, pad_size: int): data_len = data.shape[1] chunks = [] n = math.ceil(data_len / chunk_size) @@ -82,28 +87,34 @@ def get_chunks(data, chunk_size, pad_size): # input -def get_sentences(args): +def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'): # construct dataset for evaluation sentences = [] - with open(args.text, 'rt') as f: + with open(text_file, 'rt') as f: for line in f: items = line.strip().split() utt_id = items[0] - if 'lang' in args and args.lang == 'zh': + if lang == 'zh': sentence = "".join(items[1:]) - elif 'lang' in args and args.lang == 'en': + elif lang == 'en': sentence = " ".join(items[1:]) sentences.append((utt_id, sentence)) return sentences -def get_test_dataset(args, test_metadata, am_name, am_dataset): +def get_test_dataset(test_metadata: List[Dict[str, Any]], + am: str, + speaker_dict: Optional[os.PathLike]=None, + voice_cloning: bool=False): + # model: {model_name}_{dataset} + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] if am_name == 'fastspeech2': fields = ["utt_id", "text"] - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None: print("multiple speaker fastspeech2!") fields += ["spk_id"] - elif 'voice_cloning' in args and args.voice_cloning: + elif voice_cloning: print("voice cloning!") fields += ["spk_emb"] else: @@ -112,7 +123,7 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset): fields = ["utt_id", "phones", "tones"] elif am_name == 'tacotron2': fields = ["utt_id", "text"] - if 'voice_cloning' in args and args.voice_cloning: + if voice_cloning: print("voice cloning!") fields += ["spk_emb"] @@ -121,12 +132,14 @@ def get_test_dataset(args, test_metadata, am_name, am_dataset): # frontend -def get_frontend(args): - if 'lang' in args and args.lang == 'zh': +def get_frontend(lang: str='zh', + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None): + if lang == 'zh': frontend = Frontend( - phone_vocab_path=args.phones_dict, tone_vocab_path=args.tones_dict) - elif 'lang' in args and args.lang == 'en': - frontend = English(phone_vocab_path=args.phones_dict) + phone_vocab_path=phones_dict, 
tone_vocab_path=tones_dict) + elif lang == 'en': + frontend = English(phone_vocab_path=phones_dict) else: print("wrong lang!") print("frontend done!") @@ -134,30 +147,37 @@ def get_frontend(args): # dygraph -def get_am_inference(args, am_config): - with open(args.phones_dict, "r") as f: +def get_am_inference( + am: str='fastspeech2_csmsc', + am_config: CfgNode=None, + am_ckpt: Optional[os.PathLike]=None, + am_stat: Optional[os.PathLike]=None, + phones_dict: Optional[os.PathLike]=None, + tones_dict: Optional[os.PathLike]=None, + speaker_dict: Optional[os.PathLike]=None, ): + with open(phones_dict, "r") as f: phn_id = [line.strip().split() for line in f.readlines()] vocab_size = len(phn_id) print("vocab_size:", vocab_size) tone_size = None - if 'tones_dict' in args and args.tones_dict: - with open(args.tones_dict, "r") as f: + if tones_dict is not None: + with open(tones_dict, "r") as f: tone_id = [line.strip().split() for line in f.readlines()] tone_size = len(tone_id) print("tone_size:", tone_size) spk_num = None - if 'speaker_dict' in args and args.speaker_dict: - with open(args.speaker_dict, 'rt') as f: + if speaker_dict is not None: + with open(speaker_dict, 'rt') as f: spk_id = [line.strip().split() for line in f.readlines()] spk_num = len(spk_id) print("spk_num:", spk_num) odim = am_config.n_mels # model: {model_name}_{dataset} - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] am_class = dynamic_import(am_name, model_alias) am_inference_class = dynamic_import(am_name + '_inference', model_alias) @@ -174,34 +194,38 @@ def get_am_inference(args, am_config): elif am_name == 'tacotron2': am = am_class(idim=vocab_size, odim=odim, **am_config["model"]) - am.set_state_dict(paddle.load(args.am_ckpt)["main_params"]) + am.set_state_dict(paddle.load(am_ckpt)["main_params"]) am.eval() - am_mu, am_std = np.load(args.am_stat) + am_mu, am_std = np.load(am_stat) am_mu = paddle.to_tensor(am_mu) am_std = paddle.to_tensor(am_std) am_normalizer = ZScore(am_mu, am_std) am_inference = am_inference_class(am_normalizer, am) am_inference.eval() print("acoustic model done!") - return am_inference, am_name, am_dataset + return am_inference -def get_voc_inference(args, voc_config): +def get_voc_inference( + voc: str='pwgan_csmsc', + voc_config: Optional[os.PathLike]=None, + voc_ckpt: Optional[os.PathLike]=None, + voc_stat: Optional[os.PathLike]=None, ): # model: {model_name}_{dataset} - voc_name = args.voc[:args.voc.rindex('_')] + voc_name = voc[:voc.rindex('_')] voc_class = dynamic_import(voc_name, model_alias) voc_inference_class = dynamic_import(voc_name + '_inference', model_alias) if voc_name != 'wavernn': voc = voc_class(**voc_config["generator_params"]) - voc.set_state_dict(paddle.load(args.voc_ckpt)["generator_params"]) + voc.set_state_dict(paddle.load(voc_ckpt)["generator_params"]) voc.remove_weight_norm() voc.eval() else: voc = voc_class(**voc_config["model"]) - voc.set_state_dict(paddle.load(args.voc_ckpt)["main_params"]) + voc.set_state_dict(paddle.load(voc_ckpt)["main_params"]) voc.eval() - voc_mu, voc_std = np.load(args.voc_stat) + voc_mu, voc_std = np.load(voc_stat) voc_mu = paddle.to_tensor(voc_mu) voc_std = paddle.to_tensor(voc_std) voc_normalizer = ZScore(voc_mu, voc_std) @@ -211,10 +235,16 @@ def get_voc_inference(args, voc_config): return voc_inference -# to static -def am_to_static(args, am_inference, am_name, am_dataset): +# dygraph to static graph +def am_to_static(am_inference, 
+ am: str='fastspeech2_csmsc', + inference_dir=Optional[os.PathLike], + speaker_dict: Optional[os.PathLike]=None): + # model: {model_name}_{dataset} + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] if am_name == 'fastspeech2': - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None: am_inference = jit.to_static( am_inference, input_spec=[ @@ -226,7 +256,7 @@ def am_to_static(args, am_inference, am_name, am_dataset): am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) elif am_name == 'speedyspeech': - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None: am_inference = jit.to_static( am_inference, input_spec=[ @@ -247,56 +277,64 @@ def am_to_static(args, am_inference, am_name, am_dataset): am_inference = jit.to_static( am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) - paddle.jit.save(am_inference, os.path.join(args.inference_dir, args.am)) - am_inference = paddle.jit.load(os.path.join(args.inference_dir, args.am)) + paddle.jit.save(am_inference, os.path.join(inference_dir, am)) + am_inference = paddle.jit.load(os.path.join(inference_dir, am)) return am_inference -def voc_to_static(args, voc_inference): +def voc_to_static(voc_inference, + voc: str='pwgan_csmsc', + inference_dir=Optional[os.PathLike]): voc_inference = jit.to_static( voc_inference, input_spec=[ InputSpec([-1, 80], dtype=paddle.float32), ]) - paddle.jit.save(voc_inference, os.path.join(args.inference_dir, args.voc)) - voc_inference = paddle.jit.load(os.path.join(args.inference_dir, args.voc)) + paddle.jit.save(voc_inference, os.path.join(inference_dir, voc)) + voc_inference = paddle.jit.load(os.path.join(inference_dir, voc)) return voc_inference # inference -def get_predictor(args, filed='am'): - full_name = '' - if filed == 'am': - full_name = args.am - elif filed == 'voc': - full_name = args.voc +def get_predictor(model_dir: Optional[os.PathLike]=None, + model_file: Optional[os.PathLike]=None, + params_file: Optional[os.PathLike]=None, + device: str='cpu'): + config = inference.Config( - str(Path(args.inference_dir) / (full_name + ".pdmodel")), - str(Path(args.inference_dir) / (full_name + ".pdiparams"))) - if args.device == "gpu": + str(Path(model_dir) / model_file), str(Path(model_dir) / params_file)) + if device == "gpu": config.enable_use_gpu(100, 0) - elif args.device == "cpu": + elif device == "cpu": config.disable_gpu() config.enable_memory_optim() predictor = inference.create_predictor(config) return predictor -def get_am_output(args, am_predictor, frontend, merge_sentences, input): - am_name = args.am[:args.am.rindex('_')] - am_dataset = args.am[args.am.rindex('_') + 1:] +def get_am_output( + input: str, + am_predictor, + am, + frontend, + lang: str='zh', + merge_sentences: bool=True, + speaker_dict: Optional[os.PathLike]=None, + spk_id: int=0, ): + am_name = am[:am.rindex('_')] + am_dataset = am[am.rindex('_') + 1:] am_input_names = am_predictor.get_input_names() get_tone_ids = False get_spk_id = False if am_name == 'speedyspeech': get_tone_ids = True - if am_dataset in {"aishell3", "vctk"} and args.speaker_dict: + if am_dataset in {"aishell3", "vctk"} and speaker_dict: get_spk_id = True - spk_id = np.array([args.spk_id]) - if args.lang == 'zh': + spk_id = np.array([spk_id]) + if lang == 'zh': input_ids = frontend.get_input_ids( input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) phone_ids = 
input_ids["phone_ids"] - elif args.lang == 'en': + elif lang == 'en': input_ids = frontend.get_input_ids( input, merge_sentences=merge_sentences) phone_ids = input_ids["phone_ids"] @@ -338,50 +376,6 @@ def get_voc_output(voc_predictor, input): return wav -# streaming am -def get_streaming_am_predictor(args): - full_name = args.am - am_encoder_infer_config = inference.Config( - str( - Path(args.inference_dir) / - (full_name + "_am_encoder_infer" + ".pdmodel")), - str( - Path(args.inference_dir) / - (full_name + "_am_encoder_infer" + ".pdiparams"))) - am_decoder_config = inference.Config( - str( - Path(args.inference_dir) / - (full_name + "_am_decoder" + ".pdmodel")), - str( - Path(args.inference_dir) / - (full_name + "_am_decoder" + ".pdiparams"))) - am_postnet_config = inference.Config( - str( - Path(args.inference_dir) / - (full_name + "_am_postnet" + ".pdmodel")), - str( - Path(args.inference_dir) / - (full_name + "_am_postnet" + ".pdiparams"))) - if args.device == "gpu": - am_encoder_infer_config.enable_use_gpu(100, 0) - am_decoder_config.enable_use_gpu(100, 0) - am_postnet_config.enable_use_gpu(100, 0) - elif args.device == "cpu": - am_encoder_infer_config.disable_gpu() - am_decoder_config.disable_gpu() - am_postnet_config.disable_gpu() - - am_encoder_infer_config.enable_memory_optim() - am_decoder_config.enable_memory_optim() - am_postnet_config.enable_memory_optim() - - am_encoder_infer_predictor = inference.create_predictor( - am_encoder_infer_config) - am_decoder_predictor = inference.create_predictor(am_decoder_config) - am_postnet_predictor = inference.create_predictor(am_postnet_config) - return am_encoder_infer_predictor, am_decoder_predictor, am_postnet_predictor - - def get_am_sublayer_output(am_sublayer_predictor, input): am_sublayer_input_names = am_sublayer_predictor.get_input_names() input_handle = am_sublayer_predictor.get_input_handle( @@ -397,11 +391,15 @@ def get_am_sublayer_output(am_sublayer_predictor, input): return am_sublayer_output -def get_streaming_am_output(args, am_encoder_infer_predictor, - am_decoder_predictor, am_postnet_predictor, - frontend, merge_sentences, input): +def get_streaming_am_output(input: str, + am_encoder_infer_predictor, + am_decoder_predictor, + am_postnet_predictor, + frontend, + lang: str='zh', + merge_sentences: bool=True): get_tone_ids = False - if args.lang == 'zh': + if lang == 'zh': input_ids = frontend.get_input_ids( input, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] @@ -423,58 +421,27 @@ def get_streaming_am_output(args, am_encoder_infer_predictor, return normalized_mel -def get_sess(args, filed='am'): - full_name = '' - if filed == 'am': - full_name = args.am - elif filed == 'voc': - full_name = args.voc - model_dir = str(Path(args.inference_dir) / (full_name + ".onnx")) +# onnx +def get_sess(model_dir: Optional[os.PathLike]=None, + model_file: Optional[os.PathLike]=None, + device: str='cpu', + cpu_threads: int=1, + use_trt: bool=False): + + model_dir = str(Path(model_dir) / model_file) sess_options = ort.SessionOptions() sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - if args.device == "gpu": + if device == "gpu": # fastspeech2/mb_melgan can't use trt now! 
- if args.use_trt: + if use_trt: providers = ['TensorrtExecutionProvider'] else: providers = ['CUDAExecutionProvider'] - elif args.device == "cpu": + elif device == "cpu": providers = ['CPUExecutionProvider'] - sess_options.intra_op_num_threads = args.cpu_threads + sess_options.intra_op_num_threads = cpu_threads sess = ort.InferenceSession( model_dir, providers=providers, sess_options=sess_options) return sess - - -# streaming am -def get_streaming_am_sess(args): - full_name = args.am - am_encoder_infer_model_dir = str( - Path(args.inference_dir) / (full_name + "_am_encoder_infer" + ".onnx")) - am_decoder_model_dir = str( - Path(args.inference_dir) / (full_name + "_am_decoder" + ".onnx")) - am_postnet_model_dir = str( - Path(args.inference_dir) / (full_name + "_am_postnet" + ".onnx")) - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL - if args.device == "gpu": - # fastspeech2/mb_melgan can't use trt now! - if args.use_trt: - providers = ['TensorrtExecutionProvider'] - else: - providers = ['CUDAExecutionProvider'] - elif args.device == "cpu": - providers = ['CPUExecutionProvider'] - sess_options.intra_op_num_threads = args.cpu_threads - am_encoder_infer_sess = ort.InferenceSession( - am_encoder_infer_model_dir, - providers=providers, - sess_options=sess_options) - am_decoder_sess = ort.InferenceSession( - am_decoder_model_dir, providers=providers, sess_options=sess_options) - am_postnet_sess = ort.InferenceSession( - am_postnet_model_dir, providers=providers, sess_options=sess_options) - return am_encoder_infer_sess, am_decoder_sess, am_postnet_sess diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py index abb1eb4e..dd66e54e 100644 --- a/paddlespeech/t2s/exps/synthesize.py +++ b/paddlespeech/t2s/exps/synthesize.py @@ -50,11 +50,29 @@ def evaluate(args): print(voc_config) # acoustic model - am_inference, am_name, am_dataset = get_am_inference(args, am_config) - test_dataset = get_test_dataset(args, test_metadata, am_name, am_dataset) + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + + am_inference = get_am_inference( + am=args.am, + am_config=am_config, + am_ckpt=args.am_ckpt, + am_stat=args.am_stat, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict, + speaker_dict=args.speaker_dict) + test_dataset = get_test_dataset( + test_metadata=test_metadata, + am=args.am, + speaker_dict=args.speaker_dict, + voice_cloning=args.voice_cloning) # vocoder - voc_inference = get_voc_inference(args, voc_config) + voc_inference = get_voc_inference( + voc=args.voc, + voc_config=voc_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 6c28dc48..2f14ef56 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -42,24 +42,48 @@ def evaluate(args): print(am_config) print(voc_config) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) # acoustic model - am_inference, am_name, am_dataset = get_am_inference(args, am_config) + am_name = args.am[:args.am.rindex('_')] + 
am_dataset = args.am[args.am.rindex('_') + 1:] + + am_inference = get_am_inference( + am=args.am, + am_config=am_config, + am_ckpt=args.am_ckpt, + am_stat=args.am_stat, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict, + speaker_dict=args.speaker_dict) # vocoder - voc_inference = get_voc_inference(args, voc_config) + voc_inference = get_voc_inference( + voc=args.voc, + voc_config=voc_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) # whether dygraph to static if args.inference_dir: # acoustic model - am_inference = am_to_static(args, am_inference, am_name, am_dataset) + am_inference = am_to_static( + am_inference=am_inference, + am=args.am, + inference_dir=args.inference_dir, + speaker_dict=args.speaker_dict) # vocoder - voc_inference = voc_to_static(args, voc_inference) + voc_inference = voc_to_static( + voc_inference=voc_inference, + voc=args.voc, + inference_dir=args.inference_dir) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/paddlespeech/t2s/exps/synthesize_streaming.py b/paddlespeech/t2s/exps/synthesize_streaming.py index 4f7a84e9..3659cb49 100644 --- a/paddlespeech/t2s/exps/synthesize_streaming.py +++ b/paddlespeech/t2s/exps/synthesize_streaming.py @@ -49,10 +49,13 @@ def evaluate(args): print(am_config) print(voc_config) - sentences = get_sentences(args) + sentences = get_sentences(text_file=args.text, lang=args.lang) # frontend - frontend = get_frontend(args) + frontend = get_frontend( + lang=args.lang, + phones_dict=args.phones_dict, + tones_dict=args.tones_dict) with open(args.phones_dict, "r") as f: phn_id = [line.strip().split() for line in f.readlines()] @@ -60,7 +63,6 @@ def evaluate(args): print("vocab_size:", vocab_size) # acoustic model, only support fastspeech2 here now! 
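+    # Note: the streaming pipeline drives fastspeech2 as three sub-networks
+    # (encoder_infer -> decoder -> postnet), so the whole-model wrapper returned
+    # by get_am_inference() is not used here.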
- # am_inference, am_name, am_dataset = get_am_inference(args, am_config) # model: {model_name}_{dataset} am_name = args.am[:args.am.rindex('_')] am_dataset = args.am[args.am.rindex('_') + 1:] @@ -80,7 +82,11 @@ def evaluate(args): am_postnet = am.postnet # vocoder - voc_inference = get_voc_inference(args, voc_config) + voc_inference = get_voc_inference( + voc=args.voc, + voc_config=voc_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) # whether dygraph to static if args.inference_dir: @@ -115,7 +121,10 @@ def evaluate(args): os.path.join(args.inference_dir, args.am + "_am_postnet")) # vocoder - voc_inference = voc_to_static(args, voc_inference) + voc_inference = voc_to_static( + voc_inference=voc_inference, + voc=args.voc, + inference_dir=args.inference_dir) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/paddlespeech/t2s/exps/voice_cloning.py b/paddlespeech/t2s/exps/voice_cloning.py index 1afd21df..9257b07d 100644 --- a/paddlespeech/t2s/exps/voice_cloning.py +++ b/paddlespeech/t2s/exps/voice_cloning.py @@ -66,10 +66,19 @@ def voice_cloning(args): print("frontend done!") # acoustic model - am_inference, *_ = get_am_inference(args, am_config) + am_inference = get_am_inference( + am=args.am, + am_config=am_config, + am_ckpt=args.am_ckpt, + am_stat=args.am_stat, + phones_dict=args.phones_dict) # vocoder - voc_inference = get_voc_inference(args, voc_config) + voc_inference = get_voc_inference( + voc=args.voc, + voc_config=voc_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/paddlespeech/t2s/exps/wavernn/synthesize.py b/paddlespeech/t2s/exps/wavernn/synthesize.py index d23e9cb7..ea48a617 100644 --- a/paddlespeech/t2s/exps/wavernn/synthesize.py +++ b/paddlespeech/t2s/exps/wavernn/synthesize.py @@ -58,8 +58,7 @@ def main(): else: print("ngpu should >= 0 !") - model = WaveRNN( - hop_length=config.n_shift, sample_rate=config.fs, **config["model"]) + model = WaveRNN(**config["model"]) state_dict = paddle.load(args.checkpoint) model.set_state_dict(state_dict["main_params"]) diff --git a/paddlespeech/vector/modules/loss.py b/paddlespeech/vector/modules/loss.py index 1c80dda4..9a7530c1 100644 --- a/paddlespeech/vector/modules/loss.py +++ b/paddlespeech/vector/modules/loss.py @@ -91,3 +91,199 @@ class LogSoftmaxWrapper(nn.Layer): predictions = F.log_softmax(predictions, axis=1) loss = self.criterion(predictions, targets) / targets.sum() return loss + + +class NCELoss(nn.Layer): + """Noise Contrastive Estimation loss funtion + + Noise Contrastive Estimation (NCE) is an approximation method that is used to + work around the huge computational cost of large softmax layer. + The basic idea is to convert the prediction problem into classification problem + at training stage. It has been proved that these two criterions converges to + the same minimal point as long as noise distribution is close enough to real one. + + NCE bridges the gap between generative models and discriminative models, + rather than simply speedup the softmax layer. + With NCE, you can turn almost anything into posterior with less effort (I think). 
+
+    Refs:
+        NCE: http://www.cs.helsinki.fi/u/ahyvarin/papers/Gutmann10AISTATS.pdf
+    Thanks: https://github.com/mingen-pan/easy-to-use-NCE-RNN-for-Pytorch/blob/master/nce.py
+
+    Examples:
+        Q = Q_from_tokens(output_dim)
+        NCELoss(Q)
+    """
+
+    def __init__(self, Q, noise_ratio=100, Z_offset=9.5):
+        """Noise Contrastive Estimation loss function
+
+        Args:
+            Q (tensor): prior model, uniform or gaussian
+            noise_ratio (int, optional): number of noise samples per target. Defaults to 100.
+            Z_offset (float, optional): offset used when post-processing the score. Defaults to 9.5.
+        """
+        super(NCELoss, self).__init__()
+        assert type(noise_ratio) is int
+        self.Q = paddle.to_tensor(Q, stop_gradient=False)
+        self.N = self.Q.shape[0]
+        self.K = noise_ratio
+        self.Z_offset = Z_offset
+
+    def forward(self, output, target):
+        """Forward inference
+
+        Args:
+            output (tensor): the model output, which is the input of loss function
+            target (tensor): the target labels
+        """
+        output = paddle.reshape(output, [-1, self.N])
+        B = output.shape[0]
+        noise_idx = self.get_noise(B)
+        idx = self.get_combined_idx(target, noise_idx)
+        P_target, P_noise = self.get_prob(idx, output, sep_target=True)
+        Q_target, Q_noise = self.get_Q(idx)
+        loss = self.nce_loss(P_target, P_noise, Q_noise, Q_target)
+        return loss.mean()
+
+    def get_Q(self, idx, sep_target=True):
+        """Get the prior probabilities for a batch of indices
+        """
+        idx_size = idx.size
+        prob_model = paddle.to_tensor(
+            self.Q.numpy()[paddle.reshape(idx, [-1]).numpy()])
+        prob_model = paddle.reshape(prob_model, [idx.shape[0], idx.shape[1]])
+        if sep_target:
+            return prob_model[:, 0], prob_model[:, 1:]
+        else:
+            return prob_model
+
+    def get_prob(self, idx, scores, sep_target=True):
+        """Post-process the model scores (network output) for a batch of data
+        """
+        scores = self.get_scores(idx, scores)
+        scale = paddle.to_tensor([self.Z_offset], dtype='float64')
+        scores = paddle.add(scores, -scale)
+        prob = paddle.exp(scores)
+        if sep_target:
+            return prob[:, 0], prob[:, 1:]
+        else:
+            return prob
+
+    def get_scores(self, idx, scores):
+        """Gather the model scores (network output) for a batch of indices
+        """
+        B, N = scores.shape
+        K = idx.shape[1]
+        idx_increment = paddle.to_tensor(
+            N * paddle.reshape(paddle.arange(B), [B, 1]) * paddle.ones([1, K]),
+            dtype="int64",
+            stop_gradient=False)
+        new_idx = idx_increment + idx
+        new_scores = paddle.index_select(
+            paddle.reshape(scores, [-1]), paddle.reshape(new_idx, [-1]))
+
+        return paddle.reshape(new_scores, [B, K])
+
+    def get_noise(self, batch_size, uniform=True):
+        """Sample noise indices
+        """
+        if uniform:
+            noise = np.random.randint(self.N, size=self.K * batch_size)
+        else:
+            noise = np.random.choice(
+                self.N, self.K * batch_size, replace=True, p=self.Q.data)
+        noise = paddle.to_tensor(noise, dtype='int64', stop_gradient=False)
+        noise_idx = paddle.reshape(noise, [batch_size, self.K])
+        return noise_idx
+
+    def get_combined_idx(self, target_idx, noise_idx):
+        """Concatenate the target and noise indices
+        """
+        target_idx = paddle.reshape(target_idx, [-1, 1])
+        return paddle.concat((target_idx, noise_idx), 1)
+
+    def nce_loss(self, prob_model, prob_noise_in_model, prob_noise,
+                 prob_target_in_noise):
+        """Combine the target and noise terms of the loss
+        """
+
+        def safe_log(tensor):
+            """Safe log
+            """
+            EPSILON = 1e-10
+            return paddle.log(EPSILON + tensor)
+
+        model_loss = safe_log(prob_model /
+                              (prob_model + self.K * prob_target_in_noise))
+        model_loss = paddle.reshape(model_loss, [-1])
+
+        noise_loss = paddle.sum(
+            safe_log((self.K * prob_noise) /
+                     (prob_noise_in_model + self.K * prob_noise)), -1)
+        noise_loss = paddle.reshape(noise_loss, [-1])
+
+        loss = -(model_loss + noise_loss)
+
+        return loss
+
+
+class FocalLoss(nn.Layer):
+    """This criterion is an implementation of Focal Loss, which is proposed in
+    Focal Loss for Dense Object Detection.
+
+        Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
+
+    The losses are averaged across observations for each minibatch.
+
+    Args:
+        alpha(1D Tensor, Variable) : the scalar factor for this criterion
+        gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
+            putting more focus on hard, misclassified examples
+        size_average(bool): By default, the losses are averaged over observations for each minibatch.
+            However, if the field size_average is set to False, the losses are
+            instead summed for each minibatch.
+    """
+
+    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=-100):
+        super(FocalLoss, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.size_average = size_average
+        self.ce = nn.CrossEntropyLoss(
+            ignore_index=ignore_index, reduction="none")
+
+    def forward(self, outputs, targets):
+        """Forward inference.
+
+        Args:
+            outputs: input tensor
+            targets: target label tensor
+        """
+        ce_loss = self.ce(outputs, targets)
+        pt = paddle.exp(-ce_loss)
+        focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
+        if self.size_average:
+            return focal_loss.mean()
+        else:
+            return focal_loss.sum()
+
+
+if __name__ == "__main__":
+    import numpy as np
+    from paddlespeech.vector.utils.vector_utils import Q_from_tokens
+    paddle.set_device("cpu")
+
+    input_data = paddle.uniform([5, 100], dtype="float64")
+    label_data = np.random.randint(0, 100, size=(5)).astype(np.int64)
+
+    input = paddle.to_tensor(input_data)
+    label = paddle.to_tensor(label_data)
+
+    loss1 = FocalLoss()
+    loss = loss1.forward(input, label)
+    print("loss: %.5f" % (loss))
+
+    Q = Q_from_tokens(100)
+    loss2 = NCELoss(Q)
+    loss = loss2.forward(input, label)
+    print("loss: %.5f" % (loss))
diff --git a/paddlespeech/vector/utils/vector_utils.py b/paddlespeech/vector/utils/vector_utils.py
index 46de7ffa..d6659e3f 100644
--- a/paddlespeech/vector/utils/vector_utils.py
+++ b/paddlespeech/vector/utils/vector_utils.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import paddle


 def get_chunks(seg_dur, audio_id, audio_duration):
@@ -30,3 +31,11 @@ def get_chunks(seg_dur, audio_id, audio_duration):
         for i in range(num_chunks)
     ]
     return chunk_lst
+
+
+def Q_from_tokens(token_num):
+    """Get the prior model; currently uniform, others (e.g. gaussian) may be supported in the future
+    """
+    freq = [1] * token_num
+    Q = paddle.to_tensor(freq, dtype='float64')
+    return Q / Q.sum()
diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt
index f1330d1d..98d9e637 100644
--- a/speechx/CMakeLists.txt
+++ b/speechx/CMakeLists.txt
@@ -63,7 +63,8 @@ include(libsndfile)
 # include(boost) # not work
 set(boost_SOURCE_DIR ${fc_patch}/boost-src)
 set(BOOST_ROOT ${boost_SOURCE_DIR})
-# #find_package(boost REQUIRED PATHS ${BOOST_ROOT})
+include_directories(${boost_SOURCE_DIR})
+link_directories(${boost_SOURCE_DIR}/stage/lib)

 # Eigen
 include(eigen)
@@ -141,4 +142,4 @@ set(DEPS ${DEPS}

 set(SPEECHX_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/speechx)
 add_subdirectory(speechx)
-add_subdirectory(examples)
\ No newline at end of file
+add_subdirectory(examples)
diff --git a/speechx/examples/ds2_ol/CMakeLists.txt b/speechx/examples/ds2_ol/CMakeLists.txt
index 89cbd0ef..08c19484 100644
--- a/speechx/examples/ds2_ol/CMakeLists.txt
+++ b/speechx/examples/ds2_ol/CMakeLists.txt
@@ -2,4 +2,5 @@ cmake_minimum_required(VERSION 3.14 FATAL_ERROR)

 add_subdirectory(feat)
 add_subdirectory(nnet)
-add_subdirectory(decoder)
\ No newline at end of file
+add_subdirectory(decoder)
+add_subdirectory(websocket)
diff --git a/speechx/examples/ds2_ol/aishell/path.sh b/speechx/examples/ds2_ol/aishell/path.sh
index 0a300f36..520129ea 100644
--- a/speechx/examples/ds2_ol/aishell/path.sh
+++ b/speechx/examples/ds2_ol/aishell/path.sh
@@ -1,6 +1,6 @@
 # This contains the locations of binarys build required for running the examples.

-SPEECHX_ROOT=$PWD/../../../
+SPEECHX_ROOT=$PWD/../../..
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_TOOLS=$SPEECHX_ROOT/tools @@ -10,5 +10,5 @@ TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin export LC_AL=C -SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat -export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN \ No newline at end of file +SPEECHX_BIN=$SPEECHX_EXAMPLES/ds2_ol/decoder:$SPEECHX_EXAMPLES/ds2_ol/feat:$SPEECHX_EXAMPLES/ds2_ol/websocket +export PATH=$PATH:$SPEECHX_BIN:$TOOLS_BIN diff --git a/speechx/examples/ds2_ol/aishell/run.sh b/speechx/examples/ds2_ol/aishell/run.sh index a61785dc..779123d5 100755 --- a/speechx/examples/ds2_ol/aishell/run.sh +++ b/speechx/examples/ds2_ol/aishell/run.sh @@ -87,7 +87,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ctc-prefix-beam-search-decoder-ol \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ + --params_path=$model_dir/avg_1.jit.pdiparams \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --dict_file=$vocb_dir/vocab.txt \ --result_wspecifier=ark,t:$data/split${nj}/JOB/result @@ -102,7 +102,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then ctc-prefix-beam-search-decoder-ol \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ + --params_path=$model_dir/avg_1.jit.pdiparams \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --dict_file=$vocb_dir/vocab.txt \ --lm_path=$lm \ @@ -129,7 +129,7 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then wfst-decoder-ol \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --model_path=$model_dir/avg_1.jit.pdmodel \ - --param_path=$model_dir/avg_1.jit.pdiparams \ + --params_path=$model_dir/avg_1.jit.pdiparams \ --word_symbol_table=$graph_dir/words.txt \ --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ --graph_path=$graph_dir/TLG.fst --max_active=7500 \ diff --git a/speechx/examples/ds2_ol/aishell/websocket_client.sh b/speechx/examples/ds2_ol/aishell/websocket_client.sh new file mode 100644 index 00000000..3c6b4e91 --- /dev/null +++ b/speechx/examples/ds2_ol/aishell/websocket_client.sh @@ -0,0 +1,37 @@ +#!/bin/bash +set +x +set -e + +. path.sh + +# 1. compile +if [ ! -d ${SPEECHX_EXAMPLES} ]; then + pushd ${SPEECHX_ROOT} + bash build.sh + popd +fi + +# input +mkdir -p data +data=$PWD/data +ckpt_dir=$data/model +model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ +vocb_dir=$ckpt_dir/data/lang_char +# output +aishell_wav_scp=aishell_test.scp +if [ ! -d $data/test ]; then + pushd $data + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip + unzip aishell_test.zip + popd + + realpath $data/test/*/*.wav > $data/wavlist + awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id + paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp +fi + +export GLOG_logtostderr=1 + +# websocket client +websocket_client_main \ + --wav_rspecifier=scp:$data/$aishell_wav_scp --streaming_chunk=0.36 diff --git a/speechx/examples/ds2_ol/aishell/websocket_server.sh b/speechx/examples/ds2_ol/aishell/websocket_server.sh new file mode 100644 index 00000000..ea619d54 --- /dev/null +++ b/speechx/examples/ds2_ol/aishell/websocket_server.sh @@ -0,0 +1,66 @@ +#!/bin/bash +set +x +set -e + +. path.sh + + +# 1. compile +if [ ! 
-d ${SPEECHX_EXAMPLES} ]; then + pushd ${SPEECHX_ROOT} + bash build.sh + popd +fi + +# input +mkdir -p data +data=$PWD/data +ckpt_dir=$data/model +model_dir=$ckpt_dir/exp/deepspeech2_online/checkpoints/ +vocb_dir=$ckpt_dir/data/lang_char/ + +# output +aishell_wav_scp=aishell_test.scp +if [ ! -d $data/test ]; then + pushd $data + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip + unzip aishell_test.zip + popd + + realpath $data/test/*/*.wav > $data/wavlist + awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id + paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp +fi + + +if [ ! -d $ckpt_dir ]; then + mkdir -p $ckpt_dir + wget -P $ckpt_dir -c https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz + tar xzfv $ckpt_dir/asr0_deepspeech2_online_aishell_ckpt_0.2.0.model.tar.gz -C $ckpt_dir +fi + + +export GLOG_logtostderr=1 + +# 3. gen cmvn +cmvn=$PWD/cmvn.ark +cmvn-json2kaldi --json_file=$ckpt_dir/data/mean_std.json --cmvn_write_path=$cmvn + +text=$data/test/text +graph_dir=./aishell_graph +if [ ! -d $graph_dir ]; then + wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_graph.zip + unzip aishell_graph.zip +fi + +# 5. test websocket server +websocket_server_main \ + --cmvn_file=$cmvn \ + --model_path=$model_dir/avg_1.jit.pdmodel \ + --streaming_chunk=0.1 \ + --convert2PCM32=true \ + --params_path=$model_dir/avg_1.jit.pdiparams \ + --word_symbol_table=$graph_dir/words.txt \ + --model_output_names=softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0 \ + --graph_path=$graph_dir/TLG.fst --max_active=7500 \ + --acoustic_scale=1.2 diff --git a/speechx/examples/ds2_ol/decoder/CMakeLists.txt b/speechx/examples/ds2_ol/decoder/CMakeLists.txt index 6139ebfa..62dd6862 100644 --- a/speechx/examples/ds2_ol/decoder/CMakeLists.txt +++ b/speechx/examples/ds2_ol/decoder/CMakeLists.txt @@ -17,3 +17,6 @@ add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) +add_executable(recognizer_test_main ${CMAKE_CURRENT_SOURCE_DIR}/recognizer_test_main.cc) +target_include_directories(recognizer_test_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(recognizer_test_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder ${DEPS}) diff --git a/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc b/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc index 49d64b69..e145f6ee 100644 --- a/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc +++ b/speechx/examples/ds2_ol/decoder/ctc-prefix-beam-search-decoder-ol.cc @@ -34,12 +34,10 @@ DEFINE_int32(receptive_field_length, DEFINE_int32(downsampling_rate, 4, "two CNN(kernel=5) module downsampling rate."); -DEFINE_string( - model_input_names, - "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box", - "model input names"); DEFINE_string(model_output_names, - "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0", + "save_infer_model/scale_0.tmp_1,save_infer_model/" + "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/" + "scale_3.tmp_1", "model output names"); DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); @@ -58,12 +56,11 @@ int 
main(int argc, char* argv[]) { kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_rspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); - - std::string model_graph = FLAGS_model_path; + std::string model_path = FLAGS_model_path; std::string model_params = FLAGS_param_path; std::string dict_file = FLAGS_dict_file; std::string lm_path = FLAGS_lm_path; - LOG(INFO) << "model path: " << model_graph; + LOG(INFO) << "model path: " << model_path; LOG(INFO) << "model param: " << model_params; LOG(INFO) << "dict path: " << dict_file; LOG(INFO) << "lm path: " << lm_path; @@ -76,10 +73,9 @@ int main(int argc, char* argv[]) { ppspeech::CTCBeamSearch decoder(opts); ppspeech::ModelOptions model_opts; - model_opts.model_path = model_graph; + model_opts.model_path = model_path; model_opts.params_path = model_params; model_opts.cache_shape = FLAGS_model_cache_names; - model_opts.input_names = FLAGS_model_input_names; model_opts.output_names = FLAGS_model_output_names; std::shared_ptr nnet( new ppspeech::PaddleNnet(model_opts)); @@ -125,7 +121,6 @@ int main(int argc, char* argv[]) { if (feature_chunk_size < receptive_field_length) break; int32 start = chunk_idx * chunk_stride; - int32 end = start + chunk_size; for (int row_id = 0; row_id < chunk_size; ++row_id) { kaldi::SubVector tmp(feature, start); diff --git a/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc b/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc new file mode 100644 index 00000000..198a8ec2 --- /dev/null +++ b/speechx/examples/ds2_ol/decoder/recognizer_test_main.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "decoder/recognizer.h" +#include "decoder/param.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/table-types.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure(); + ppspeech::Recognizer recognizer(resource); + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + int sample_rate = 16000; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + int32 num_done = 0, num_err = 0; + + for (; !wav_reader.Done(); wav_reader.Next()) { + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + std::vector> feats; + int feature_rows = 0; + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + kaldi::Vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk(i) = waveform(sample_offset + i); + } + + recognizer.Accept(wav_chunk); + if (cur_chunk_size < chunk_sample_size) { + recognizer.SetFinished(); + } + recognizer.Decode(); + + sample_offset += cur_chunk_size; + } + std::string result; + result = recognizer.GetFinalResult(); + recognizer.Reset(); + if (result.empty()) { + // the TokenWriter can not write empty string. 
+ ++num_err; + KALDI_LOG << " the result of " << utt << " is empty"; + continue; + } + KALDI_LOG << " the result of " << utt << " is " << result; + result_writer.Write(utt, result); + ++num_done; + } +} \ No newline at end of file diff --git a/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc b/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc index b8385664..0a9cfb06 100644 --- a/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc +++ b/speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc @@ -73,9 +73,9 @@ int main(int argc, char* argv[]) { LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path; LOG(INFO) << "Binary: " << FLAGS_binary; } catch (simdjson::simdjson_error& err) { - LOG(ERR) << err.what(); + LOG(ERROR) << err.what(); } return 0; -} \ No newline at end of file +} diff --git a/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc b/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc index 27ca6f9f..0d10bd30 100644 --- a/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc +++ b/speechx/examples/ds2_ol/feat/linear-spectrogram-wo-db-norm-ol.cc @@ -32,7 +32,6 @@ DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); - int main(int argc, char* argv[]) { gflags::ParseCommandLineFlags(&argc, &argv, false); google::InitGoogleLogging(argv[0]); @@ -66,7 +65,8 @@ int main(int argc, char* argv[]) { std::unique_ptr cmvn( new ppspeech::CMVN(FLAGS_cmvn_file, std::move(linear_spectrogram))); - ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); + ppspeech::FeatureCacheOptions feat_cache_opts; + ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); LOG(INFO) << "feat dim: " << feature_cache.Dim(); int sample_rate = 16000; diff --git a/speechx/examples/ds2_ol/websocket/CMakeLists.txt b/speechx/examples/ds2_ol/websocket/CMakeLists.txt new file mode 100644 index 00000000..754b528e --- /dev/null +++ b/speechx/examples/ds2_ol/websocket/CMakeLists.txt @@ -0,0 +1,10 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +add_executable(websocket_server_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_server_main.cc) +target_include_directories(websocket_server_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(websocket_server_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS}) + +add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc) +target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) +target_link_libraries(websocket_client_main PUBLIC frontend kaldi-feat-common nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util kaldi-decoder websocket ${DEPS}) + diff --git a/speechx/examples/ds2_ol/websocket/websocket_client_main.cc b/speechx/examples/ds2_ol/websocket/websocket_client_main.cc new file mode 100644 index 00000000..68ea898a --- /dev/null +++ b/speechx/examples/ds2_ol/websocket/websocket_client_main.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "websocket/websocket_client.h" +#include "kaldi/feat/wave-reader.h" +#include "kaldi/util/kaldi-io.h" +#include "kaldi/util/table-types.h" + +DEFINE_string(host, "127.0.0.1", "host of websocket server"); +DEFINE_int32(port, 201314, "port of websocket server"); +DEFINE_string(wav_rspecifier, "", "test wav scp path"); +DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size"); + +using kaldi::int16; +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + ppspeech::WebSocketClient client(FLAGS_host, FLAGS_port); + + kaldi::SequentialTableReader wav_reader( + FLAGS_wav_rspecifier); + + const int sample_rate = 16000; + const float streaming_chunk = FLAGS_streaming_chunk; + const int chunk_sample_size = streaming_chunk * sample_rate; + + for (; !wav_reader.Done(); wav_reader.Next()) { + client.SendStartSignal(); + std::string utt = wav_reader.Key(); + const kaldi::WaveData& wave_data = wav_reader.Value(); + CHECK_EQ(wave_data.SampFreq(), sample_rate); + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + const int tot_samples = waveform.Dim(); + int sample_offset = 0; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = static_cast(waveform(sample_offset + i)); + } + client.SendBinaryData(wav_chunk.data(), + wav_chunk.size() * sizeof(int16)); + + + sample_offset += cur_chunk_size; + LOG(INFO) << "Send " << cur_chunk_size << " samples"; + std::this_thread::sleep_for( + std::chrono::milliseconds(static_cast(1 * 1000))); + + if (cur_chunk_size < chunk_sample_size) { + client.SendEndSignal(); + } + } + + while (!client.Done()) { + } + std::string result = client.GetResult(); + LOG(INFO) << "utt: " << utt << " " << result; + + + client.Join(); + return 0; + } + return 0; +} diff --git a/speechx/examples/ds2_ol/websocket/websocket_server_main.cc b/speechx/examples/ds2_ol/websocket/websocket_server_main.cc new file mode 100644 index 00000000..43cbd6bb --- /dev/null +++ b/speechx/examples/ds2_ol/websocket/websocket_server_main.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "websocket/websocket_server.h" +#include "decoder/param.h" + +DEFINE_int32(port, 201314, "websocket listening port"); + +int main(int argc, char *argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + ppspeech::RecognizerResource resource = ppspeech::InitRecognizerResoure(); + + ppspeech::WebSocketServer server(FLAGS_port, resource); + LOG(INFO) << "Listening at port " << FLAGS_port; + server.Start(); + return 0; +} diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index 225abee7..b4da095d 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -30,4 +30,10 @@ include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/decoder ) -add_subdirectory(decoder) \ No newline at end of file +add_subdirectory(decoder) + +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/websocket +) +add_subdirectory(websocket) diff --git a/speechx/speechx/base/common.h b/speechx/speechx/base/common.h index 7502bc5e..a9303cbb 100644 --- a/speechx/speechx/base/common.h +++ b/speechx/speechx/base/common.h @@ -28,8 +28,10 @@ #include #include #include +#include #include #include +#include #include #include "base/basic_types.h" diff --git a/speechx/speechx/decoder/CMakeLists.txt b/speechx/speechx/decoder/CMakeLists.txt index ee0863fd..06bf4020 100644 --- a/speechx/speechx/decoder/CMakeLists.txt +++ b/speechx/speechx/decoder/CMakeLists.txt @@ -7,5 +7,6 @@ add_library(decoder STATIC ctc_decoders/path_trie.cpp ctc_decoders/scorer.cpp ctc_tlg_decoder.cc + recognizer.cc ) -target_link_libraries(decoder PUBLIC kenlm utils fst) +target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/speechx/speechx/decoder/ctc_tlg_decoder.cc index 5365e709..7b720e7b 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/speechx/speechx/decoder/ctc_tlg_decoder.cc @@ -33,7 +33,6 @@ void TLGDecoder::InitDecoder() { void TLGDecoder::AdvanceDecode( const std::shared_ptr& decodable) { while (!decodable->IsLastFrame(frame_decoded_size_)) { - LOG(INFO) << "num frame decode: " << frame_decoded_size_; AdvanceDecoding(decodable.get()); } } @@ -63,4 +62,4 @@ std::string TLGDecoder::GetFinalBestPath() { } return words; } -} \ No newline at end of file +} diff --git a/speechx/speechx/decoder/param.h b/speechx/speechx/decoder/param.h new file mode 100644 index 00000000..cd50ef53 --- /dev/null +++ b/speechx/speechx/decoder/param.h @@ -0,0 +1,94 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "base/common.h" + +#include "decoder/ctc_beam_search_decoder.h" +#include "decoder/ctc_tlg_decoder.h" +#include "frontend/audio/feature_pipeline.h" + +DEFINE_string(cmvn_file, "", "read cmvn"); +DEFINE_double(streaming_chunk, 0.1, "streaming feature chunk size"); +DEFINE_bool(convert2PCM32, true, "audio convert to pcm32"); +DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); +DEFINE_string(params_path, "avg_1.jit.pdiparams", "paddle nnet model param"); +DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); +DEFINE_string(graph_path, "TLG", "decoder graph"); +DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); +DEFINE_int32(max_active, 7500, "max active"); +DEFINE_double(beam, 15.0, "decoder beam"); +DEFINE_double(lattice_beam, 7.5, "decoder beam"); +DEFINE_int32(receptive_field_length, + 7, + "receptive field of two CNN(kernel=5) downsampling module."); +DEFINE_int32(downsampling_rate, + 4, + "two CNN(kernel=5) module downsampling rate."); +DEFINE_string(model_output_names, + "save_infer_model/scale_0.tmp_1,save_infer_model/" + "scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/" + "scale_3.tmp_1", + "model output names"); +DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names"); + +namespace ppspeech { +// todo refactor later +FeaturePipelineOptions InitFeaturePipelineOptions() { + FeaturePipelineOptions opts; + opts.cmvn_file = FLAGS_cmvn_file; + opts.linear_spectrogram_opts.streaming_chunk = FLAGS_streaming_chunk; + opts.convert2PCM32 = FLAGS_convert2PCM32; + kaldi::FrameExtractionOptions frame_opts; + frame_opts.frame_length_ms = 20; + frame_opts.frame_shift_ms = 10; + frame_opts.remove_dc_offset = false; + frame_opts.window_type = "hanning"; + frame_opts.preemph_coeff = 0.0; + frame_opts.dither = 0.0; + opts.linear_spectrogram_opts.frame_opts = frame_opts; + opts.feature_cache_opts.frame_chunk_size = FLAGS_receptive_field_length; + opts.feature_cache_opts.frame_chunk_stride = FLAGS_downsampling_rate; + return opts; +} + +ModelOptions InitModelOptions() { + ModelOptions model_opts; + model_opts.model_path = FLAGS_model_path; + model_opts.params_path = FLAGS_params_path; + model_opts.cache_shape = FLAGS_model_cache_names; + model_opts.output_names = FLAGS_model_output_names; + return model_opts; +} + +TLGDecoderOptions InitDecoderOptions() { + TLGDecoderOptions decoder_opts; + decoder_opts.word_symbol_table = FLAGS_word_symbol_table; + decoder_opts.fst_path = FLAGS_graph_path; + decoder_opts.opts.max_active = FLAGS_max_active; + decoder_opts.opts.beam = FLAGS_beam; + decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; + return decoder_opts; +} + +RecognizerResource InitRecognizerResoure() { + RecognizerResource resource; + resource.acoustic_scale = FLAGS_acoustic_scale; + resource.feature_pipeline_opts = InitFeaturePipelineOptions(); + resource.model_opts = InitModelOptions(); + resource.tlg_opts = InitDecoderOptions(); + return resource; +} +} \ No newline at end of file diff --git a/speechx/speechx/decoder/recognizer.cc b/speechx/speechx/decoder/recognizer.cc new file mode 100644 index 00000000..2c90ada9 --- /dev/null +++ b/speechx/speechx/decoder/recognizer.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "decoder/recognizer.h" + +namespace ppspeech { + +using kaldi::Vector; +using kaldi::VectorBase; +using kaldi::BaseFloat; +using std::vector; +using kaldi::SubVector; +using std::unique_ptr; + +Recognizer::Recognizer(const RecognizerResource& resource) { + // resource_ = resource; + const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts; + feature_pipeline_.reset(new FeaturePipeline(feature_opts)); + std::shared_ptr nnet(new PaddleNnet(resource.model_opts)); + BaseFloat ac_scale = resource.acoustic_scale; + decodable_.reset(new Decodable(nnet, feature_pipeline_, ac_scale)); + decoder_.reset(new TLGDecoder(resource.tlg_opts)); + input_finished_ = false; +} + +void Recognizer::Accept(const Vector& waves) { + feature_pipeline_->Accept(waves); +} + +void Recognizer::Decode() { decoder_->AdvanceDecode(decodable_); } + +std::string Recognizer::GetFinalResult() { + return decoder_->GetFinalBestPath(); +} + +void Recognizer::SetFinished() { + feature_pipeline_->SetFinished(); + input_finished_ = true; +} + +bool Recognizer::IsFinished() { return input_finished_; } + +void Recognizer::Reset() { + feature_pipeline_->Reset(); + decodable_->Reset(); + decoder_->Reset(); +} + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/decoder/recognizer.h b/speechx/speechx/decoder/recognizer.h new file mode 100644 index 00000000..9a7e7d11 --- /dev/null +++ b/speechx/speechx/decoder/recognizer.h @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// todo refactor later (SGoat) + +#pragma once + +#include "decoder/ctc_beam_search_decoder.h" +#include "decoder/ctc_tlg_decoder.h" +#include "frontend/audio/feature_pipeline.h" +#include "nnet/decodable.h" +#include "nnet/paddle_nnet.h" + +namespace ppspeech { + +struct RecognizerResource { + FeaturePipelineOptions feature_pipeline_opts; + ModelOptions model_opts; + TLGDecoderOptions tlg_opts; + // CTCBeamSearchOptions beam_search_opts; + kaldi::BaseFloat acoustic_scale; + RecognizerResource() + : acoustic_scale(1.0), + feature_pipeline_opts(), + model_opts(), + tlg_opts() {} +}; + +class Recognizer { + public: + explicit Recognizer(const RecognizerResource& resouce); + void Accept(const kaldi::Vector& waves); + void Decode(); + std::string GetFinalResult(); + void SetFinished(); + bool IsFinished(); + void Reset(); + + private: + // std::shared_ptr resource_; + // RecognizerResource resource_; + std::shared_ptr feature_pipeline_; + std::shared_ptr decodable_; + std::unique_ptr decoder_; + bool input_finished_; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/audio/CMakeLists.txt b/speechx/speechx/frontend/audio/CMakeLists.txt index 35243b6e..2d20edf7 100644 --- a/speechx/speechx/frontend/audio/CMakeLists.txt +++ b/speechx/speechx/frontend/audio/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(frontend STATIC linear_spectrogram.cc audio_cache.cc feature_cache.cc + feature_pipeline.cc ) -target_link_libraries(frontend PUBLIC kaldi-matrix) \ No newline at end of file +target_link_libraries(frontend PUBLIC kaldi-matrix kaldi-feat-common) diff --git a/speechx/speechx/frontend/audio/audio_cache.cc b/speechx/speechx/frontend/audio/audio_cache.cc index 50aca4fb..e8af6668 100644 --- a/speechx/speechx/frontend/audio/audio_cache.cc +++ b/speechx/speechx/frontend/audio/audio_cache.cc @@ -41,7 +41,7 @@ void AudioCache::Accept(const VectorBase& waves) { ready_feed_condition_.wait(lock); } for (size_t idx = 0; idx < waves.Dim(); ++idx) { - int32 buffer_idx = (idx + offset_) % ring_buffer_.size(); + int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size(); ring_buffer_[buffer_idx] = waves(idx); if (convert2PCM32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx)); diff --git a/speechx/speechx/frontend/audio/audio_cache.h b/speechx/speechx/frontend/audio/audio_cache.h index adef1239..a681ef09 100644 --- a/speechx/speechx/frontend/audio/audio_cache.h +++ b/speechx/speechx/frontend/audio/audio_cache.h @@ -24,7 +24,7 @@ namespace ppspeech { class AudioCache : public FrontendInterface { public: explicit AudioCache(int buffer_size = 1000 * kint16max, - bool convert2PCM32 = false); + bool convert2PCM32 = true); virtual void Accept(const kaldi::VectorBase& waves); diff --git a/speechx/speechx/frontend/audio/feature_cache.cc b/speechx/speechx/frontend/audio/feature_cache.cc index 3f7f6502..b5768460 100644 --- a/speechx/speechx/frontend/audio/feature_cache.cc +++ b/speechx/speechx/frontend/audio/feature_cache.cc @@ -23,10 +23,13 @@ using std::vector; using kaldi::SubVector; using std::unique_ptr; -FeatureCache::FeatureCache(int max_size, +FeatureCache::FeatureCache(FeatureCacheOptions opts, unique_ptr base_extractor) { - max_size_ = max_size; + max_size_ = opts.max_size; + frame_chunk_stride_ = opts.frame_chunk_stride; + frame_chunk_size_ = opts.frame_chunk_size; base_extractor_ = std::move(base_extractor); + dim_ = base_extractor_->Dim(); } void FeatureCache::Accept(const kaldi::VectorBase& inputs) { @@ -44,13 +47,14 @@ bool 
FeatureCache::Read(kaldi::Vector* feats) { std::unique_lock lock(mutex_); while (cache_.empty() && base_extractor_->IsFinished() == false) { - ready_read_condition_.wait(lock); - BaseFloat elapsed = timer.Elapsed() * 1000; - // todo replace 1.0 with timeout_ - if (elapsed > 1.0) { + // todo refactor: wait + // ready_read_condition_.wait(lock); + int32 elapsed = static_cast(timer.Elapsed() * 1000); + // todo replace 1 with timeout_, 1 ms + if (elapsed > 1) { return false; } - usleep(1000); // sleep 1 ms + usleep(100); // sleep 0.1 ms } if (cache_.empty()) return false; feats->Resize(cache_.front().Dim()); @@ -63,25 +67,41 @@ bool FeatureCache::Read(kaldi::Vector* feats) { // read all data from base_feature_extractor_ into cache_ bool FeatureCache::Compute() { // compute and feed - Vector feature_chunk; - bool result = base_extractor_->Read(&feature_chunk); + Vector feature; + bool result = base_extractor_->Read(&feature); + if (result == false || feature.Dim() == 0) return false; + int32 joint_len = feature.Dim() + remained_feature_.Dim(); + int32 num_chunk = + ((joint_len / dim_) - frame_chunk_size_) / frame_chunk_stride_ + 1; - std::unique_lock lock(mutex_); - while (cache_.size() >= max_size_) { - ready_feed_condition_.wait(lock); - } + Vector joint_feature(joint_len); + joint_feature.Range(0, remained_feature_.Dim()) + .CopyFromVec(remained_feature_); + joint_feature.Range(remained_feature_.Dim(), feature.Dim()) + .CopyFromVec(feature); - // feed cache - if (feature_chunk.Dim() != 0) { + for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { + int32 start = chunk_idx * frame_chunk_stride_ * dim_; + Vector feature_chunk(frame_chunk_size_ * dim_); + SubVector tmp(joint_feature.Data() + start, + frame_chunk_size_ * dim_); + feature_chunk.CopyFromVec(tmp); + + std::unique_lock lock(mutex_); + while (cache_.size() >= max_size_) { + ready_feed_condition_.wait(lock); + } + + // feed cache cache_.push(feature_chunk); + ready_read_condition_.notify_one(); } - ready_read_condition_.notify_one(); + int32 remained_feature_len = + joint_len - num_chunk * frame_chunk_stride_ * dim_; + remained_feature_.Resize(remained_feature_len); + remained_feature_.CopyFromVec(joint_feature.Range( + frame_chunk_stride_ * num_chunk * dim_, remained_feature_len)); return result; } -void Reset() { - // std::lock_guard lock(mutex_); - return; -} - } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/audio/feature_cache.h b/speechx/speechx/frontend/audio/feature_cache.h index 99961b5e..607f72c0 100644 --- a/speechx/speechx/frontend/audio/feature_cache.h +++ b/speechx/speechx/frontend/audio/feature_cache.h @@ -19,10 +19,18 @@ namespace ppspeech { +struct FeatureCacheOptions { + int32 max_size; + int32 frame_chunk_size; + int32 frame_chunk_stride; + FeatureCacheOptions() + : max_size(kint16max), frame_chunk_size(1), frame_chunk_stride(1) {} +}; + class FeatureCache : public FrontendInterface { public: explicit FeatureCache( - int32 max_size = kint16max, + FeatureCacheOptions opts, std::unique_ptr base_extractor = NULL); // Feed feats or waves @@ -32,12 +40,15 @@ class FeatureCache : public FrontendInterface { virtual bool Read(kaldi::Vector* feats); // feat dim - virtual size_t Dim() const { return base_extractor_->Dim(); } + virtual size_t Dim() const { return dim_; } virtual void SetFinished() { + // std::unique_lock lock(mutex_); base_extractor_->SetFinished(); + LOG(INFO) << "set finished"; // read the last chunk data Compute(); + // ready_feed_condition_.notify_one(); } 
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } @@ -52,9 +63,13 @@ class FeatureCache : public FrontendInterface { private: bool Compute(); + int32 dim_; size_t max_size_; - std::unique_ptr base_extractor_; + int32 frame_chunk_size_; + int32 frame_chunk_stride_; + kaldi::Vector remained_feature_; + std::unique_ptr base_extractor_; std::mutex mutex_; std::queue> cache_; std::condition_variable ready_feed_condition_; diff --git a/speechx/speechx/frontend/audio/feature_pipeline.cc b/speechx/speechx/frontend/audio/feature_pipeline.cc new file mode 100644 index 00000000..86eca2e0 --- /dev/null +++ b/speechx/speechx/frontend/audio/feature_pipeline.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "frontend/audio/feature_pipeline.h" + +namespace ppspeech { + +using std::unique_ptr; + +FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) { + unique_ptr data_source( + new ppspeech::AudioCache(1000 * kint16max, opts.convert2PCM32)); + + unique_ptr linear_spectrogram( + new ppspeech::LinearSpectrogram(opts.linear_spectrogram_opts, + std::move(data_source))); + + unique_ptr cmvn( + new ppspeech::CMVN(opts.cmvn_file, std::move(linear_spectrogram))); + + base_extractor_.reset( + new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); +} + +} // ppspeech \ No newline at end of file diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/speechx/speechx/frontend/audio/feature_pipeline.h new file mode 100644 index 00000000..7bd6c84f --- /dev/null +++ b/speechx/speechx/frontend/audio/feature_pipeline.h @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// todo refactor later (SGoat) + +#pragma once + +#include "frontend/audio/audio_cache.h" +#include "frontend/audio/data_cache.h" +#include "frontend/audio/feature_cache.h" +#include "frontend/audio/frontend_itf.h" +#include "frontend/audio/linear_spectrogram.h" +#include "frontend/audio/normalizer.h" + +namespace ppspeech { + +struct FeaturePipelineOptions { + std::string cmvn_file; + bool convert2PCM32; + LinearSpectrogramOptions linear_spectrogram_opts; + FeatureCacheOptions feature_cache_opts; + FeaturePipelineOptions() + : cmvn_file(""), + convert2PCM32(false), + linear_spectrogram_opts(), + feature_cache_opts() {} +}; + +class FeaturePipeline : public FrontendInterface { + public: + explicit FeaturePipeline(const FeaturePipelineOptions& opts); + virtual void Accept(const kaldi::VectorBase& waves) { + base_extractor_->Accept(waves); + } + virtual bool Read(kaldi::Vector* feats) { + return base_extractor_->Read(feats); + } + virtual size_t Dim() const { return base_extractor_->Dim(); } + virtual void SetFinished() { base_extractor_->SetFinished(); } + virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + virtual void Reset() { base_extractor_->Reset(); } + + private: + std::unique_ptr base_extractor_; +}; +} \ No newline at end of file diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.cc b/speechx/speechx/frontend/audio/linear_spectrogram.cc index d6ae3d01..9ef5e766 100644 --- a/speechx/speechx/frontend/audio/linear_spectrogram.cc +++ b/speechx/speechx/frontend/audio/linear_spectrogram.cc @@ -52,16 +52,16 @@ bool LinearSpectrogram::Read(Vector* feats) { if (flag == false || input_feats.Dim() == 0) return false; int32 feat_len = input_feats.Dim(); - int32 left_len = reminded_wav_.Dim(); + int32 left_len = remained_wav_.Dim(); Vector waves(feat_len + left_len); - waves.Range(0, left_len).CopyFromVec(reminded_wav_); + waves.Range(0, left_len).CopyFromVec(remained_wav_); waves.Range(left_len, feat_len).CopyFromVec(input_feats); Compute(waves, feats); int32 frame_shift = opts_.frame_opts.WindowShift(); int32 num_frames = kaldi::NumFrames(waves.Dim(), opts_.frame_opts); int32 left_samples = waves.Dim() - frame_shift * num_frames; - reminded_wav_.Resize(left_samples); - reminded_wav_.CopyFromVec( + remained_wav_.Resize(left_samples); + remained_wav_.CopyFromVec( waves.Range(frame_shift * num_frames, left_samples)); return true; } diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.h b/speechx/speechx/frontend/audio/linear_spectrogram.h index 689ec2c4..2764b7cf 100644 --- a/speechx/speechx/frontend/audio/linear_spectrogram.h +++ b/speechx/speechx/frontend/audio/linear_spectrogram.h @@ -25,12 +25,12 @@ struct LinearSpectrogramOptions { kaldi::FrameExtractionOptions frame_opts; kaldi::BaseFloat streaming_chunk; // second - LinearSpectrogramOptions() : streaming_chunk(0.36), frame_opts() {} + LinearSpectrogramOptions() : streaming_chunk(0.1), frame_opts() {} void Register(kaldi::OptionsItf* opts) { opts->Register("streaming-chunk", &streaming_chunk, - "streaming chunk size, default: 0.36 sec"); + "streaming chunk size, default: 0.1 sec"); frame_opts.Register(opts); } }; @@ -48,7 +48,7 @@ class LinearSpectrogram : public FrontendInterface { virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual void Reset() { base_extractor_->Reset(); - reminded_wav_.Resize(0); + remained_wav_.Resize(0); } private: @@ -60,7 +60,7 @@ class LinearSpectrogram : public FrontendInterface { kaldi::BaseFloat hanning_window_energy_; 
LinearSpectrogramOptions opts_; std::unique_ptr base_extractor_; - kaldi::Vector reminded_wav_; + kaldi::Vector remained_wav_; int chunk_sample_size_; DISALLOW_COPY_AND_ASSIGN(LinearSpectrogram); }; diff --git a/speechx/speechx/nnet/decodable.cc b/speechx/speechx/nnet/decodable.cc index 3f5dadd2..465f64a9 100644 --- a/speechx/speechx/nnet/decodable.cc +++ b/speechx/speechx/nnet/decodable.cc @@ -78,7 +78,6 @@ bool Decodable::AdvanceChunk() { } int32 nnet_dim = 0; Vector inferences; - Matrix nnet_cache_tmp; nnet_->FeedForward(features, frontend_->Dim(), &inferences, &nnet_dim); nnet_cache_.Resize(inferences.Dim() / nnet_dim, nnet_dim); nnet_cache_.CopyRowsFromVec(inferences);