Merge branch 'PaddlePaddle:develop' into rhy

pull/2548/head
HuangLiangJie 2 years ago committed by GitHub
commit 872be9c8ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,13 +50,20 @@ repos:
entry: bash .pre-commit-hooks/clang-format.hook -i entry: bash .pre-commit-hooks/clang-format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$
#- id: copyright_checker #- id: copyright_checker
# name: copyright_checker # name: copyright_checker
# entry: python .pre-commit-hooks/copyright-check.hook # entry: python .pre-commit-hooks/copyright-check.hook
# language: system # language: system
# files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$ # files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py)$
# exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$ # exclude: (?=third_party|pypinyin|speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin).*(\.cpp|\.cc|\.h|\.py)$
- id: cpplint
name: cpplint
description: Static code analysis of C/C++ files
language: python
files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$
exclude: (?=speechx/speechx/kaldi|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.py)$
entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent
- repo: https://github.com/asottile/reorder_python_imports - repo: https://github.com/asottile/reorder_python_imports
rev: v2.4.0 rev: v2.4.0
hooks: hooks:

@ -157,6 +157,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision
- 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV). - 🧩 *Cascaded models application*: as an extension of the typical traditional audio tasks, we combine the workflows of the aforementioned tasks with other fields like Natural language processing (NLP) and Computer Vision (CV).
### Recent Update ### Recent Update
- 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend.
- 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech. - 👑 2022.10.11: Add [Wav2vec2ASR](./examples/librispeech/asr3), wav2vec2.0 fine-tuning for ASR on LibriSpeech.
- 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and ERNIE-SAT in [PaddleSpeech Web Demo](./demos/speech_web). - 🔥 2022.09.26: Add Voice Cloning, TTS finetune, and ERNIE-SAT in [PaddleSpeech Web Demo](./demos/speech_web).
- ⚡ 2022.09.09: Add AISHELL-3 Voice Cloning [example](./examples/aishell3/vc2) with ECAPA-TDNN speaker encoder. - ⚡ 2022.09.09: Add AISHELL-3 Voice Cloning [example](./examples/aishell3/vc2) with ECAPA-TDNN speaker encoder.
@ -923,8 +924,8 @@ You are warmly welcome to submit questions in [discussions](https://github.com/P
## Acknowledgement ## Acknowledgement
- Many thanks to [HighCWu](https://github.com/HighCWu) for adding [VITS-aishell3](./examples/aishell3/vits) and [VITS-VC](./examples/aishell3/vits-vc) examples. - Many thanks to [HighCWu](https://github.com/HighCWu) for adding [VITS-aishell3](./examples/aishell3/vits) and [VITS-VC](./examples/aishell3/vits-vc) examples.
- Many thanks to [david-95](https://github.com/david-95) improved TTS, fixed multi-punctuation bug, and contributed to multiple program and data. - Many thanks to [david-95](https://github.com/david-95) for fixing a multi-punctuation bug, contributing multiple programs and data, and adding [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for the TTS Chinese Text Frontend.
- Many thanks to [BarryKCL](https://github.com/BarryKCL) improved TTS Chinses frontend based on [G2PW](https://github.com/GitYCC/g2pW). - Many thanks to [BarryKCL](https://github.com/BarryKCL) for improving the TTS Chinese Frontend based on [G2PW](https://github.com/GitYCC/g2pW).
- Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention, constructive advice and great help. - Many thanks to [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) for years of attention, constructive advice and great help.
- Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR upon [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files. - Many thanks to [mymagicpower](https://github.com/mymagicpower) for the Java implementation of ASR upon [short](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk) and [long](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk) audio files.
- Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for developing Virtual Uploader(VUP)/Virtual YouTuber(VTuber) with PaddleSpeech TTS function. - Many thanks to [JiehangXie](https://github.com/JiehangXie)/[PaddleBoBo](https://github.com/JiehangXie/PaddleBoBo) for developing Virtual Uploader(VUP)/Virtual YouTuber(VTuber) with PaddleSpeech TTS function.

@ -164,7 +164,8 @@
### 近期更新 ### 近期更新
- 👑 2022.10.11: 新增 [Wav2vec2ASR](./examples/librispeech/asr3), 在 LibriSpeech 上针对ASR任务对wav2vec2.0 的fine-tuning. - 🎉 2022.10.21: TTS 中文文本前端新增 [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) 功能。
- 👑 2022.10.11: 新增 [Wav2vec2ASR](./examples/librispeech/asr3), 在 LibriSpeech 上针对 ASR 任务对 wav2vec2.0 的 finetuning。
- 🔥 2022.09.26: 新增 Voice Cloning, TTS finetune 和 ERNIE-SAT 到 [PaddleSpeech 网页应用](./demos/speech_web)。 - 🔥 2022.09.26: 新增 Voice Cloning, TTS finetune 和 ERNIE-SAT 到 [PaddleSpeech 网页应用](./demos/speech_web)。
- ⚡ 2022.09.09: 新增基于 ECAPA-TDNN 声纹模型的 AISHELL-3 Voice Cloning [示例](./examples/aishell3/vc2)。 - ⚡ 2022.09.09: 新增基于 ECAPA-TDNN 声纹模型的 AISHELL-3 Voice Cloning [示例](./examples/aishell3/vc2)。
- ⚡ 2022.08.25: 发布 TTS [finetune](./examples/other/tts_finetune/tts3) 示例。 - ⚡ 2022.08.25: 发布 TTS [finetune](./examples/other/tts_finetune/tts3) 示例。
@ -928,7 +929,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
## 致谢 ## 致谢
- 非常感谢 [HighCWu](https://github.com/HighCWu) 新增 [VITS-aishell3](./examples/aishell3/vits) 和 [VITS-VC](./examples/aishell3/vits-vc) 代码示例。 - 非常感谢 [HighCWu](https://github.com/HighCWu) 新增 [VITS-aishell3](./examples/aishell3/vits) 和 [VITS-VC](./examples/aishell3/vits-vc) 代码示例。
- 非常感谢 [david-95](https://github.com/david-95) 修复句尾多标点符号出错的问题,贡献补充多条程序和数据。 - 非常感谢 [david-95](https://github.com/david-95) 修复 TTS 句尾多标点符号出错的问题,贡献补充多条程序和数据。为 TTS 中文文本前端新增 [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) 功能。
- 非常感谢 [BarryKCL](https://github.com/BarryKCL) 基于 [G2PW](https://github.com/GitYCC/g2pW) 对 TTS 中文文本前端的优化。 - 非常感谢 [BarryKCL](https://github.com/BarryKCL) 基于 [G2PW](https://github.com/GitYCC/g2pW) 对 TTS 中文文本前端的优化。
- 非常感谢 [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) 多年来的关注和建议,以及在诸多问题上的帮助。 - 非常感谢 [yeyupiaoling](https://github.com/yeyupiaoling)/[PPASR](https://github.com/yeyupiaoling/PPASR)/[PaddlePaddle-DeepSpeech](https://github.com/yeyupiaoling/PaddlePaddle-DeepSpeech)/[VoiceprintRecognition-PaddlePaddle](https://github.com/yeyupiaoling/VoiceprintRecognition-PaddlePaddle)/[AudioClassification-PaddlePaddle](https://github.com/yeyupiaoling/AudioClassification-PaddlePaddle) 多年来的关注和建议,以及在诸多问题上的帮助。
- 非常感谢 [mymagicpower](https://github.com/mymagicpower) 采用PaddleSpeech 对 ASR 的[短语音](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk)及[长语音](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk)进行 Java 实现。 - 非常感谢 [mymagicpower](https://github.com/mymagicpower) 采用PaddleSpeech 对 ASR 的[短语音](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_sdk)及[长语音](https://github.com/mymagicpower/AIAS/tree/main/3_audio_sdks/asr_long_audio_sdk)进行 Java 实现。

@ -53,3 +53,22 @@ Pretrain model from https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1
| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | -1 | 0.061884 | | conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | -1 | 0.061884 |
| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | -1 | 0.062056 | | conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | -1 | 0.062056 |
| conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | -1 | 0.052110 | | conformer | 32.52 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | -1 | 0.052110 |
## U2PP Streaming Pretrained Model
Pretrain model from https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_u2pp_wenetspeech_ckpt_1.3.0.model.tar.gz
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention | 16 | 0.057031 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | 16 | 0.068826 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | 16 | 0.069111 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | 16 | 0.059213 |
| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size | CER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention | -1 | 0.049256 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_greedy_search | -1 | 0.052086 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | ctc_prefix_beam_search | -1 | 0.052267 |
| conformer | 122.88 M | conf/chunk_conformer.yaml | spec_aug | aishell1 | attention_rescoring | -1 | 0.047198 |

@ -42,6 +42,7 @@ for type in attention_rescoring; do
output_dir=${ckpt_prefix} output_dir=${ckpt_prefix}
mkdir -p ${output_dir} mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test_wav.py \ python3 -u ${BIN_DIR}/test_wav.py \
--debug True \
--ngpu ${ngpu} \ --ngpu ${ngpu} \
--config ${config_path} \ --config ${config_path} \
--decode_cfg ${decode_config_path} \ --decode_cfg ${decode_config_path} \

@ -16,6 +16,8 @@ import os
import sys import sys
from pathlib import Path from pathlib import Path
import distutils.util
import numpy as np
import paddle import paddle
import soundfile import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
@ -74,6 +76,8 @@ class U2Infer():
# fbank # fbank
feat = self.preprocessing(audio, **self.preprocess_args) feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}") logger.info(f"feat shape: {feat.shape}")
if self.args.debug:
np.savetxt("feat.transform.txt", feat)
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
@ -126,6 +130,11 @@ if __name__ == "__main__":
"--result_file", type=str, help="path of save the asr result") "--result_file", type=str, help="path of save the asr result")
parser.add_argument( parser.add_argument(
"--audio_file", type=str, help="path of the input audio file") "--audio_file", type=str, help="path of the input audio file")
parser.add_argument(
"--debug",
type=distutils.util.strtobool,
default=False,
help="for debug.")
args = parser.parse_args() args = parser.parse_args()
config = CfgNode(new_allowed=True) config = CfgNode(new_allowed=True)
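The `--debug` flag above parses its value with `distutils.util.strtobool`. `distutils` was deprecated by PEP 632 and removed from the standard library in Python 3.12, so the following is a minimal sketch of a roughly equivalent parser. `str2bool` and `build_parser` are hypothetical names used for illustration only; they are not part of this commit.

```python
import argparse


def str2bool(value: str) -> bool:
    """Hypothetical stand-in for distutils.util.strtobool (returns bool, not int)."""
    lowered = value.strip().lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    # argparse reports this as a normal usage error for the flag.
    raise argparse.ArgumentTypeError(f"invalid truth value: {value!r}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Same flag name, default, and help text as in the diff above.
    parser.add_argument(
        "--debug", type=str2bool, default=False, help="for debug.")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--debug", "True"])
    print(args.debug)
```

With this shape, `--debug True` (as passed by the shell script in this commit) and `--debug 0` are both accepted, and omitting the flag leaves debugging off.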

@ -113,7 +113,7 @@ class ServerExecutor(BaseExecutor):
""" """
config = get_config(config_file) config = get_config(config_file)
if self.init(config): if self.init(config):
uvicorn.run(app, host=config.host, port=config.port, debug=True) uvicorn.run(app, host=config.host, port=config.port)
@cli_server_register( @cli_server_register(

@ -18,5 +18,6 @@ from . import exps
from . import frontend from . import frontend
from . import models from . import models
from . import modules from . import modules
from . import ssml
from . import training from . import training
from . import utils from . import utils

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import math import math
import os import os
import re
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from typing import Dict from typing import Dict
@ -33,6 +34,7 @@ from paddlespeech.t2s.frontend.mix_frontend import MixFrontend
from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.modules.normalizer import ZScore from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.utils.dynamic_import import dynamic_import from paddlespeech.utils.dynamic_import import dynamic_import
# remove [W:onnxruntime: xxx] from ort # remove [W:onnxruntime: xxx] from ort
ort.set_default_logger_severity(3) ort.set_default_logger_severity(3)
@ -103,7 +105,8 @@ def get_sentences(text_file: Optional[os.PathLike], lang: str='zh'):
sentences = [] sentences = []
with open(text_file, 'rt') as f: with open(text_file, 'rt') as f:
for line in f: for line in f:
items = line.strip().split() if line.strip() != "":
items = re.split(r"\s+", line.strip(), 1)
utt_id = items[0] utt_id = items[0]
if lang == 'zh': if lang == 'zh':
sentence = "".join(items[1:]) sentence = "".join(items[1:])
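The change from `line.strip().split()` to `re.split(r"\s+", line.strip(), 1)` matters for SSML input: a plain `split()` also breaks on the spaces inside tag attributes, and the later `"".join(items[1:])` would then drop them. A minimal sketch of the new behavior follows; `split_utt_line` is a hypothetical helper and the `say-as` tag name in the usage note is an assumption, since the diff itself only checks for a `pinyin` attribute.

```python
import re


def split_utt_line(line: str):
    """Split 'utt_id text...' into (utt_id, rest); return None for blank lines."""
    stripped = line.strip()
    if stripped == "":
        return None
    # Split only at the first run of whitespace so the text keeps its spaces.
    items = re.split(r"\s+", stripped, maxsplit=1)
    utt_id = items[0]
    rest = items[1] if len(items) > 1 else ""
    return utt_id, rest
```

For a line such as `001 <speak>你<say-as pinyin="hao3">好</say-as></speak>`, the space between `say-as` and `pinyin` survives, so the XML stays parseable.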
@ -180,6 +183,15 @@ def run_frontend(frontend: object,
to_tensor: bool=True): to_tensor: bool=True):
outs = dict() outs = dict()
if lang == 'zh': if lang == 'zh':
input_ids = {}
if text.strip() != "" and re.match(r".*?<speak>.*?</speak>.*", text,
re.DOTALL):
input_ids = frontend.get_input_ids_ssml(
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids,
to_tensor=to_tensor)
else:
input_ids = frontend.get_input_ids( input_ids = frontend.get_input_ids(
text, text,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
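The branch added to `run_frontend` routes text to `get_input_ids_ssml` only when it contains a `<speak>...</speak>` span. The check can be sketched in isolation as below; `looks_like_ssml` is a hypothetical name, and the pattern is copied from the diff.

```python
import re

# Same pattern as in run_frontend; DOTALL lets "." match newlines so that
# multi-line SSML documents are detected as well.
SSML_PATTERN = re.compile(r".*?<speak>.*?</speak>.*", re.DOTALL)


def looks_like_ssml(text: str) -> bool:
    """Return True if text is non-blank and contains a <speak>...</speak> span."""
    return text.strip() != "" and SSML_PATTERN.match(text) is not None
```

Plain text and text with an unclosed `<speak>` tag fall through to the regular `get_input_ids` path.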

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import re import re
from operator import itemgetter
from typing import Dict from typing import Dict
from typing import List from typing import List
@ -31,6 +32,7 @@ from paddlespeech.t2s.frontend.g2pw import G2PWOnnxConverter
from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon from paddlespeech.t2s.frontend.generate_lexicon import generate_lexicon
from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi from paddlespeech.t2s.frontend.tone_sandhi import ToneSandhi
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer
from paddlespeech.t2s.ssml.xml_processor import MixTextProcessor
INITIALS = [ INITIALS = [
'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h', 'zh', 'ch', 'sh',
@ -81,6 +83,7 @@ class Frontend():
g2p_model="g2pW", g2p_model="g2pW",
phone_vocab_path=None, phone_vocab_path=None,
tone_vocab_path=None): tone_vocab_path=None):
self.mix_ssml_processor = MixTextProcessor()
self.tone_modifier = ToneSandhi() self.tone_modifier = ToneSandhi()
self.text_normalizer = TextNormalizer() self.text_normalizer = TextNormalizer()
self.punc = ":,;。?!“”‘’':,;.?!" self.punc = ":,;。?!“”‘’':,;.?!"
@ -281,6 +284,65 @@ class Frontend():
phones_list.append(merge_list) phones_list.append(merge_list)
return phones_list return phones_list
def _split_word_to_char(self, words):
res = []
for x in words:
res.append(x)
return res
# if using ssml with pinyin specified, assign the pinyin to the words
def _g2p_assign(self,
words: List[str],
pinyin_spec: List[str],
merge_sentences: bool=True) -> List[List[str]]:
phones_list = []
initials = []
finals = []
words = self._split_word_to_char(words[0])
for pinyin, char in zip(pinyin_spec, words):
sub_initials = []
sub_finals = []
pinyin = pinyin.replace("u:", "v")
#self.pinyin2phone: is a dict with all pinyin mapped with sheng_mu yun_mu
if pinyin in self.pinyin2phone:
initial_final_list = self.pinyin2phone[pinyin].split(" ")
if len(initial_final_list) == 2:
sub_initials.append(initial_final_list[0])
sub_finals.append(initial_final_list[1])
elif len(initial_final_list) == 1:
sub_initials.append('')
sub_finals.append(initial_final_list[0])
else:
# If it's not pinyin (possibly punctuation) or no conversion is required
sub_initials.append(pinyin)
sub_finals.append(pinyin)
initials.append(sub_initials)
finals.append(sub_finals)
initials = sum(initials, [])
finals = sum(finals, [])
phones = []
for c, v in zip(initials, finals):
# NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii
if c and c not in self.punc:
phones.append(c)
if c and c in self.punc:
phones.append('sp')
if v and v not in self.punc:
phones.append(v)
phones_list.append(phones)
if merge_sentences:
merge_list = sum(phones_list, [])
# rm the last 'sp' to avoid the noise at the end
# cause in the training data, no 'sp' in the end
if merge_list[-1] == 'sp':
merge_list = merge_list[:-1]
phones_list = []
phones_list.append(merge_list)
return phones_list
def _merge_erhua(self, def _merge_erhua(self,
initials: List[str], initials: List[str],
finals: List[str], finals: List[str],
@ -396,6 +458,52 @@ class Frontend():
print("----------------------------") print("----------------------------")
return phonemes return phonemes
#@an added for ssml pinyin
def get_phonemes_ssml(self,
ssml_inputs: list,
merge_sentences: bool=True,
with_erhua: bool=True,
robot: bool=False,
print_info: bool=False) -> List[List[str]]:
all_phonemes = []
for word_pinyin_item in ssml_inputs:
phonemes = []
sentence, pinyin_spec = itemgetter(0, 1)(word_pinyin_item)
sentences = self.text_normalizer.normalize(sentence)
if len(pinyin_spec) == 0:
phonemes = self._g2p(
sentences,
merge_sentences=merge_sentences,
with_erhua=with_erhua)
else:
# phonemes should be pinyin_spec
phonemes = self._g2p_assign(
sentences, pinyin_spec, merge_sentences=merge_sentences)
all_phonemes = all_phonemes + phonemes
if robot:
new_phonemes = []
for sentence in all_phonemes:
new_sentence = []
for item in sentence:
# `er` only has tone `2`
if item[-1] in "12345" and item != "er2":
item = item[:-1] + "1"
new_sentence.append(item)
new_phonemes.append(new_sentence)
all_phonemes = new_phonemes
if print_info:
print("----------------------------")
print("text norm results:")
print(sentences)
print("----------------------------")
print("g2p results:")
print(all_phonemes[0])
print("----------------------------")
return [sum(all_phonemes, [])]
def get_input_ids(self, def get_input_ids(self,
sentence: str, sentence: str,
merge_sentences: bool=True, merge_sentences: bool=True,
@ -405,6 +513,7 @@ class Frontend():
add_blank: bool=False, add_blank: bool=False,
blank_token: str="<pad>", blank_token: str="<pad>",
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]: to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
phonemes = self.get_phonemes( phonemes = self.get_phonemes(
sentence, sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
@ -437,3 +546,49 @@ class Frontend():
if temp_phone_ids: if temp_phone_ids:
result["phone_ids"] = temp_phone_ids result["phone_ids"] = temp_phone_ids
return result return result
# @an added for ssml
def get_input_ids_ssml(
self,
sentence: str,
merge_sentences: bool=True,
get_tone_ids: bool=False,
robot: bool=False,
print_info: bool=False,
add_blank: bool=False,
blank_token: str="<pad>",
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
l_inputs = MixTextProcessor.get_pinyin_split(sentence)
phonemes = self.get_phonemes_ssml(
l_inputs,
merge_sentences=merge_sentences,
print_info=print_info,
robot=robot)
result = {}
phones = []
tones = []
temp_phone_ids = []
temp_tone_ids = []
for part_phonemes in phonemes:
phones, tones = self._get_phone_tone(
part_phonemes, get_tone_ids=get_tone_ids)
if add_blank:
phones = insert_after_character(phones, blank_token)
if tones:
tone_ids = self._t2id(tones)
if to_tensor:
tone_ids = paddle.to_tensor(tone_ids)
temp_tone_ids.append(tone_ids)
if phones:
phone_ids = self._p2id(phones)
# if paddle.to_tensor() is used with onnxruntime, the first call will be too slow
if to_tensor:
phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids)
if temp_tone_ids:
result["tone_ids"] = temp_tone_ids
if temp_phone_ids:
result["phone_ids"] = temp_phone_ids
return result

@ -0,0 +1,14 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .xml_processor import *

@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
import re
import xml.dom.minidom
import xml.parsers.expat
from xml.dom.minidom import Node
from xml.dom.minidom import parseString
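'''
Note: XML has 5 special characters: & < > " '
Option 1: wrap the string that contains special characters in a <![CDATA[ ]]> section,
for example
<TitleName><![CDATA["姓名"]]></TitleName>
Option 2: use XML escape sequences; the escape sequences for the 5 special characters are
& &amp;
< &lt;
> &gt;
" &quot;
' &apos;
for example
<TitleName>&quot;姓名&quot;</TitleName>
'''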
class MixTextProcessor():
def __repr__(self):
return "@an MixTextProcessor class"
def get_xml_content(self, mixstr):
'''Return the xml content of the string.'''
xmlptn = re.compile(r"<speak>.*?</speak>", re.M | re.S)
ctn = re.search(xmlptn, mixstr)
if ctn:
return ctn.group(0)
else:
return None
def get_content_split(self, mixstr):
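''' Split the text, in order, into a list of non-xml and xml parts (the corresponding strings, punctuation included).
Whitespace must not be removed because xml tag attributes contain spaces.
'''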
ctlist = []
# print("Testing:",mixstr[:20])
patn = re.compile(r'(.*\s*?)(<speak>.*?</speak>)(.*\s*)$', re.M | re.S)
mat = re.match(patn, mixstr)
if mat:
pre_xml = mat.group(1)
in_xml = mat.group(2)
after_xml = mat.group(3)
ctlist.append(pre_xml)
ctlist.append(in_xml)
ctlist.append(after_xml)
return ctlist
else:
ctlist.append(mixstr)
return ctlist
@classmethod
def get_pinyin_split(self, mixstr):
ctlist = []
patn = re.compile(r'(.*\s*?)(<speak>.*?</speak>)(.*\s*)$', re.M | re.S)
mat = re.match(patn, mixstr)
if mat:
pre_xml = mat.group(1)
in_xml = mat.group(2)
after_xml = mat.group(3)
ctlist.append([pre_xml, []])
dom = DomXml(in_xml)
pinyinlist = dom.get_pinyins_for_xml()
ctlist = ctlist + pinyinlist
ctlist.append([after_xml, []])
else:
ctlist.append([mixstr, []])
return ctlist
class DomXml():
def __init__(self, xmlstr):
self.tdom = parseString(xmlstr) #Document
self.root = self.tdom.documentElement #Element
self.rnode = self.tdom.childNodes #NodeList
def get_text(self):
'''Return a list of all the text content in the xml.'''
res = []
for x1 in self.rnode:
if x1.nodeType == Node.TEXT_NODE:
res.append(x1.data)
else:
for x2 in x1.childNodes:
if isinstance(x2, xml.dom.minidom.Text):
res.append(x2.data)
else:
for x3 in x2.childNodes:
if isinstance(x3, xml.dom.minidom.Text):
res.append(x3.data)
else:
print("len(nodes of x3):", len(x3.childNodes))
return res
def get_xmlchild_list(self):
'''Return the xml content as a list, including all text content (without tags).'''
res = []
for x1 in self.rnode:
if x1.nodeType == Node.TEXT_NODE:
res.append(x1.data)
else:
for x2 in x1.childNodes:
if isinstance(x2, xml.dom.minidom.Text):
res.append(x2.data)
else:
for x3 in x2.childNodes:
if isinstance(x3, xml.dom.minidom.Text):
res.append(x3.data)
else:
print("len(nodes of x3):", len(x3.childNodes))
print(res)
return res
def get_pinyins_for_xml(self):
'''Return a list of [string, pinyin list] pairs from the xml content.'''
res = []
for x1 in self.rnode:
if x1.nodeType == Node.TEXT_NODE:
t = re.sub(r"\s+", "", x1.data)
res.append([t, []])
else:
for x2 in x1.childNodes:
if isinstance(x2, xml.dom.minidom.Text):
t = re.sub(r"\s+", "", x2.data)
res.append([t, []])
else:
# print("x2",x2,x2.tagName)
if x2.hasAttribute('pinyin'):
pinyin_value = x2.getAttribute("pinyin")
pinyins = pinyin_value.split(" ")
for x3 in x2.childNodes:
# print('x3',x3)
if isinstance(x3, xml.dom.minidom.Text):
t = re.sub(r"\s+", "", x3.data)
res.append([t, pinyins])
else:
print("len(nodes of x3):", len(x3.childNodes))
return res
def get_all_tags(self, tag_name):
'''Get all tags named tag_name and their attribute values.'''
alltags = self.root.getElementsByTagName(tag_name)
for x in alltags:
if x.hasAttribute('pinyin'): # pinyin
print(x.tagName, 'pinyin',
x.getAttribute('pinyin'), x.firstChild.data)
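To illustrate what `get_pinyins_for_xml` produces, here is a minimal, self-contained sketch of the same idea using only `xml.dom.minidom`. It is not the implementation from this commit: `text_pinyin_pairs` is a hypothetical function, it only walks the direct children of `<speak>`, it skips empty text nodes, and the `say-as` tag name in the usage note is an assumption (the code above only looks for a `pinyin` attribute).

```python
from xml.dom.minidom import Node, parseString


def text_pinyin_pairs(ssml: str):
    """Return [text, pinyins] pairs for the direct children of the root element."""
    root = parseString(ssml).documentElement
    pairs = []
    for child in root.childNodes:
        if child.nodeType == Node.TEXT_NODE:
            # Plain text: no pinyin override, whitespace removed.
            text = "".join(child.data.split())
            if text:
                pairs.append([text, []])
        elif child.nodeType == Node.ELEMENT_NODE:
            # getAttribute returns "" when the attribute is missing,
            # which splits into an empty list.
            pinyins = child.getAttribute("pinyin").split()
            text = "".join(
                n.data for n in child.childNodes
                if n.nodeType == Node.TEXT_NODE)
            pairs.append(["".join(text.split()), pinyins])
    return pairs
```

Calling it on `<speak>前面<say-as pinyin="zhong4 yao4">重要</say-as>后面</speak>` yields one pair per segment, with the pinyin list filled in only for the tagged characters; this is the shape that `get_phonemes_ssml` iterates over.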

@ -75,6 +75,7 @@ base = [
"braceexpand", "braceexpand",
"pyyaml", "pyyaml",
"pybind11", "pybind11",
"paddleslim==2.3.4",
] ]
server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"] server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"]

@ -0,0 +1,29 @@
# This file is used by clang-format to autoformat paddle source code
#
# clang-format is part of the LLVM toolchain.
# LLVM and clang must be installed to format the source code.
#
# The basic usage is,
# clang-format -i -style=file PATH/TO/SOURCE/CODE
#
# The -style=file option implicitly uses the ".clang-format" file located in one of
# the parent directories.
# The -i option means the change is made in place.
#
# The document of clang-format is
# http://clang.llvm.org/docs/ClangFormat.html
# http://clang.llvm.org/docs/ClangFormatStyleOptions.html
---
Language: Cpp
BasedOnStyle: Google
IndentWidth: 4
TabWidth: 4
ContinuationIndentWidth: 4
MaxEmptyLinesToKeep: 2
AccessModifierOffset: -2 # The private/protected/public has no indent in class
Standard: Cpp11
AllowAllParametersOfDeclarationOnNextLine: true
BinPackParameters: false
BinPackArguments: false
...

@ -1 +1,2 @@
tools/valgrind* tools/valgrind*
*log

@ -13,7 +13,6 @@ set(CMAKE_CXX_STANDARD 14)
set(speechx_cmake_dir ${PROJECT_SOURCE_DIR}/cmake) set(speechx_cmake_dir ${PROJECT_SOURCE_DIR}/cmake)
# Modules # Modules
list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}/external)
list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir}) list(APPEND CMAKE_MODULE_PATH ${speechx_cmake_dir})
include(FetchContent) include(FetchContent)
include(ExternalProject) include(ExternalProject)
@ -32,9 +31,13 @@ SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall
############################################################################### ###############################################################################
# Option Configurations # Option Configurations
############################################################################### ###############################################################################
# option configurations
option(TEST_DEBUG "option for debug" OFF) option(TEST_DEBUG "option for debug" OFF)
option(USE_PROFILING "enable c++ profiling" OFF)
option(USING_U2 "compile u2 model." ON)
option(USING_DS2 "compile with ds2 model." ON)
option(USING_GPU "u2 compute on GPU." OFF)
############################################################################### ###############################################################################
# Include third party # Include third party
@ -83,48 +86,65 @@ add_dependencies(openfst gflags glog)
# paddle lib # paddle lib
set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib) include(paddleinference)
set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
ExternalProject_Add(paddle
URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz # paddle core.so
URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873 find_package(Threads REQUIRED)
PREFIX ${paddle_PREFIX_DIR} find_package(PythonLibs REQUIRED)
SOURCE_DIR ${paddle_SOURCE_DIR} find_package(Python3 REQUIRED)
CONFIGURE_COMMAND "" find_package(pybind11 CONFIG)
BUILD_COMMAND ""
INSTALL_COMMAND "" message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}")
) message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}")
message(STATUS "Pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}, pybind11_LIBRARIES=${pybind11_LIBRARIES}, pybind11_DEFINITIONS=${pybind11_DEFINITIONS}")
set(PADDLE_LIB ${fc_patch}/paddle-lib)
include_directories("${PADDLE_LIB}/paddle/include") # paddle include and link option
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/") # -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include") execute_process(
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include") COMMAND python -c "\
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include") import os;\
import paddle;\
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib") include_dir=paddle.sysconfig.get_include();\
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib") paddle_dir=os.path.split(include_dir)[0];\
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib") libs_dir=os.path.join(paddle_dir, 'libs');\
link_directories("${PADDLE_LIB}/paddle/lib") fluid_dir=os.path.join(paddle_dir, 'fluid');\
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib") out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir]);\
out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out);\
##paddle with mkl "
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") OUTPUT_VARIABLE PADDLE_LINK_FLAGS
set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml") RESULT_VARIABLE SUCESS)
include_directories("${MATH_LIB_PATH}/include")
set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX} message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS})
${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX}) string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS)
set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
include_directories("${MKLDNN_PATH}/include") # paddle compile option
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0) # -I/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/include
set(EXTERNAL_LIB "-lrt -ldl -lpthread") execute_process(
COMMAND python -c "\
set(DEPS ${PADDLE_LIB}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX}) import paddle; \
set(DEPS ${DEPS} include_dir = paddle.sysconfig.get_include(); \
${MATH_LIB} ${MKLDNN_LIB} print(f\"-I{include_dir}\"); \
glog gflags protobuf xxhash cryptopp "
${EXTERNAL_LIB}) OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS)
message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS})
string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
# for LD_LIBRARY_PATH
# set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
execute_process(
COMMAND python -c " \
import os; \
import paddle; \
include_dir=paddle.sysconfig.get_include(); \
paddle_dir=os.path.split(include_dir)[0]; \
libs_dir=os.path.join(paddle_dir, 'libs'); \
fluid_dir=os.path.join(paddle_dir, 'fluid'); \
out=':'.join([libs_dir, fluid_dir]); print(out); \
"
OUTPUT_VARIABLE PADDLE_LIB_DIRS)
message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS})
############################################################################### ###############################################################################

@ -3,11 +3,14 @@
## Environment ## Environment
We develop under: We develop under:
* python - 3.7
* docker - `registry.baidubce.com/paddlepaddle/paddle:2.2.2-gpu-cuda10.2-cudnn7` * docker - `registry.baidubce.com/paddlepaddle/paddle:2.2.2-gpu-cuda10.2-cudnn7`
* os - Ubuntu 16.04.7 LTS * os - Ubuntu 16.04.7 LTS
* gcc/g++/gfortran - 8.2.0 * gcc/g++/gfortran - 8.2.0
* cmake - 3.16.0 * cmake - 3.16.0
> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build speechx.
> We make sure everything works fine under docker, and recommend using it to develop and deploy. > We make sure everything works fine under docker, and recommend using it to develop and deploy.
* [How to Install Docker](https://docs.docker.com/engine/install/) * [How to Install Docker](https://docs.docker.com/engine/install/)
@ -24,16 +27,23 @@ docker run --privileged --net=host --ipc=host -it --rm -v $PWD:/workspace --nam
* More `Paddle` docker images you can see [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html). * More `Paddle` docker images you can see [here](https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/docker/linux-docker.html).
2. Create python environment.
2. Build `speechx` and `examples`. ```
bash tools/venv.sh
```
> Do not source venv. 2. Build `speechx` and `examples`.
For now we are using feature under `develop` branch of paddle, so we need to install `paddlepaddle` nightly build version.
For example:
``` ```
pushd /path/to/speechx source venv/bin/activate
python -m pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html
./build.sh ./build.sh
``` ```
3. Go to `examples` to have fun. 3. Go to `examples` to have fun.
For more details, please see `README.md` under `examples`. For more details, please see `README.md` under `examples`.
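
The build steps above assume the python `venv` is active before `./build.sh` runs. A minimal sketch of a pre-build check follows; `check_venv` is a hypothetical helper that is not part of speechx, and it relies only on the `VIRTUAL_ENV` variable that `venv/bin/activate` exports.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of speechx): fail early when no venv is active.
check_venv() {
  # `source venv/bin/activate` exports VIRTUAL_ENV; empty or unset means no venv.
  if [ -z "${VIRTUAL_ENV:-}" ]; then
    echo "no active venv"
    return 1
  fi
  echo "venv: ${VIRTUAL_ENV}"
}
```

One possible use is `check_venv && ./build.sh`, so the build is skipped when the environment has not been sourced.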

@ -1,4 +1,5 @@
#!/usr/bin/env bash #!/usr/bin/env bash
set -xe
# the build script has been verified in the paddlepaddle docker image. # the build script has been verified in the paddlepaddle docker image.
# please follow the instruction below to install PaddlePaddle image. # please follow the instruction below to install PaddlePaddle image.
@ -17,11 +18,6 @@ fi
#rm -rf build #rm -rf build
mkdir -p build mkdir -p build
cd build
cmake .. -DBOOST_ROOT:STRING=${boost_SOURCE_DIR} cmake -B build -DBOOST_ROOT:STRING=${boost_SOURCE_DIR}
#cmake .. cmake --build build -j
make -j
cd -

@ -1,12 +0,0 @@
include(FetchContent)
FetchContent_Declare(
gflags
URL https://github.com/gflags/gflags/archive/v2.2.1.zip
URL_HASH SHA256=4e44b69e709c826734dbbbd5208f61888a2faf63f239d73d8ba0011b2dccc97a
)
FetchContent_MakeAvailable(gflags)
# openfst need
include_directories(${gflags_BINARY_DIR}/include)

@ -0,0 +1,11 @@
include(FetchContent)
FetchContent_Declare(
gflags
URL https://github.com/gflags/gflags/archive/v2.2.2.zip
URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
)
FetchContent_MakeAvailable(gflags)
# openfst need
include_directories(${gflags_BINARY_DIR}/include)

@ -1,8 +1,8 @@
include(FetchContent) include(FetchContent)
FetchContent_Declare( FetchContent_Declare(
gtest gtest
URL https://github.com/google/googletest/archive/release-1.10.0.zip URL https://github.com/google/googletest/archive/release-1.11.0.zip
URL_HASH SHA256=94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91 URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
) )
FetchContent_MakeAvailable(gtest) FetchContent_MakeAvailable(gtest)

@ -1,7 +1,7 @@
include(FetchContent) include(FetchContent)
set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src) set(OpenBLAS_SOURCE_DIR ${fc_patch}/openblas-src)
set(OpenBLAS_PREFIX ${fc_patch}/OpenBLAS-prefix) set(OpenBLAS_PREFIX ${fc_patch}/openblas-prefix)
# ###################################################################################################################### # ######################################################################################################################
# OPENBLAS https://github.com/lattice/quda/blob/develop/CMakeLists.txt#L575 # OPENBLAS https://github.com/lattice/quda/blob/develop/CMakeLists.txt#L575
@ -43,6 +43,7 @@ ExternalProject_Add(
# https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition # https://cmake.org/cmake/help/latest/module/ExternalProject.html?highlight=externalproject_get_property#external-project-definition
ExternalProject_Get_Property(OPENBLAS INSTALL_DIR) ExternalProject_Get_Property(OPENBLAS INSTALL_DIR)
message(STATUS "OPENBLAS install dir: ${INSTALL_DIR}")
set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR}) set(OpenBLAS_INSTALL_PREFIX ${INSTALL_DIR})
add_library(openblas STATIC IMPORTED) add_library(openblas STATIC IMPORTED)
add_dependencies(openblas OPENBLAS) add_dependencies(openblas OPENBLAS)
@ -55,4 +56,6 @@ set_target_properties(openblas PROPERTIES IMPORTED_LOCATION ${OpenBLAS_INSTALL_P
# ${CMAKE_INSTALL_LIBDIR} lib # ${CMAKE_INSTALL_LIBDIR} lib
# ${CMAKE_INSTALL_INCLUDEDIR} include # ${CMAKE_INSTALL_INCLUDEDIR} include
link_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}) link_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}) # include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR})
# fix for can not find `cblas.h`
include_directories(${OpenBLAS_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/openblas)

@ -0,0 +1,49 @@
set(paddle_SOURCE_DIR ${fc_patch}/paddle-lib)
set(paddle_PREFIX_DIR ${fc_patch}/paddle-lib-prefix)
include(FetchContent)
FetchContent_Declare(
paddle
URL https://paddle-inference-lib.bj.bcebos.com/2.2.2/cxx_c/Linux/CPU/gcc8.2_avx_mkl/paddle_inference.tgz
URL_HASH SHA256=7c6399e778c6554a929b5a39ba2175e702e115145e8fa690d2af974101d98873
PREFIX ${paddle_PREFIX_DIR}
SOURCE_DIR ${paddle_SOURCE_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
)
FetchContent_MakeAvailable(paddle)
set(PADDLE_LIB_THIRD_PARTY_PATH "${paddle_SOURCE_DIR}/third_party/install/")
include_directories("${paddle_SOURCE_DIR}/paddle/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
link_directories("${paddle_SOURCE_DIR}/paddle/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn/lib")
##paddle with mkl
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
set(MATH_LIB_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mklml")
include_directories("${MATH_LIB_PATH}/include")
set(MATH_LIB ${MATH_LIB_PATH}/lib/libmklml_intel${CMAKE_SHARED_LIBRARY_SUFFIX}
${MATH_LIB_PATH}/lib/libiomp5${CMAKE_SHARED_LIBRARY_SUFFIX})
set(MKLDNN_PATH "${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn")
include_directories("${MKLDNN_PATH}/include")
set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
# global vars
set(DEPS ${paddle_SOURCE_DIR}/paddle/lib/libpaddle_inference${CMAKE_SHARED_LIBRARY_SUFFIX} CACHE INTERNAL "deps")
set(DEPS ${DEPS}
${MATH_LIB} ${MKLDNN_LIB}
glog gflags protobuf xxhash cryptopp
${EXTERNAL_LIB} CACHE INTERNAL "deps")
message(STATUS "Deps libraries: ${DEPS}")

@ -1,20 +1,42 @@
# Examples for SpeechX # Examples for SpeechX
> `u2pp_ol` is recommended.
* `u2pp_ol` - u2++ streaming asr test under `aishell-1` test dataset.
* `ds2_ol` - ds2 streaming test under `aishell-1` test dataset. * `ds2_ol` - ds2 streaming test under `aishell-1` test dataset.
## How to run ## How to run
`run.sh` is the entry point. ### Create env
Use `tools/env.sh` under `speechx` to create the python env.
```
bash tools/env.sh
```
Source the env before playing with the examples.
```
. venv/bin/activate
```
### Play with example
`run.sh` is the entry point for every example.
Example to play `ds2_ol`: Example to play `u2pp_ol`:
``` ```
pushd ds2_ol/aishell pushd u2pp_ol/wenetspeech
bash run.sh bash run.sh --stop_stage 4
``` ```
## Display Model with [Netron](https://github.com/lutzroeder/netron) ## Display Model with [Netron](https://github.com/lutzroeder/netron)
If you have a model, you can use this command to show the model graph.
For example:
``` ```
pip install netron pip install netron
netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20 netron exp/deepspeech2_online/checkpoints/avg_1.jit.pdmodel --port 8022 --host 10.21.55.20

@ -1,8 +1,9 @@
# Codelab # Codelab
## introduction > The below is for developing and offline testing.
> Do not run it unless you know what it is.
> The below is for developing and offline testing. Do not run it unless you know what it is.
* nnet * nnet
* feat * feat
* decoder * decoder
* u2

@ -69,7 +69,7 @@ compute_linear_spectrogram_main \
echo "compute linear spectrogram feature." echo "compute linear spectrogram feature."
# run ctc beam search decoder as streaming # run ctc beam search decoder as streaming
ctc_prefix_beam_search_decoder_main \ ctc_beam_search_decoder_main \
--result_wspecifier=ark,t:$exp_dir/result.txt \ --result_wspecifier=ark,t:$exp_dir/result.txt \
--feature_rspecifier=ark:$feat_wspecifier \ --feature_rspecifier=ark:$feat_wspecifier \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \

@ -1,12 +1,12 @@
# This contains the locations of binarys build required for running the examples. # This contains the locations of binarys build required for running the examples.
SPEECHX_ROOT=$PWD/../../../ SPEECHX_ROOT=$PWD/../../../
SPEECHX_EXAMPLES=$SPEECHX_ROOT/build/examples SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } [ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_ALL=C export LC_ALL=C

@ -42,8 +42,8 @@ mkdir -p $exp_dir
export GLOG_logtostderr=1 export GLOG_logtostderr=1
cmvn_json2kaldi_main \ cmvn_json2kaldi_main \
--json_file $model_dir/data/mean_std.json \ --json_file=$model_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \ --cmvn_write_path=$exp_dir/cmvn.ark \
--binary=false --binary=false
echo "convert json cmvn to kaldi ark." echo "convert json cmvn to kaldi ark."
@ -54,4 +54,10 @@ compute_linear_spectrogram_main \
--cmvn_file=$exp_dir/cmvn.ark --cmvn_file=$exp_dir/cmvn.ark
echo "compute linear spectrogram feature." echo "compute linear spectrogram feature."
compute_fbank_main \
--num_bins=161 \
--wav_rspecifier=scp:$data_dir/wav.scp \
--feature_wspecifier=ark,t:$exp_dir/fbank.ark \
--cmvn_file=$exp_dir/cmvn.ark
echo "compute fbank feature."

@ -6,7 +6,7 @@ SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_EXAMPLES ] || { echo "Error: 'build/examples' directory not found. please ensure that the project build successfully"; } [ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_ALL=C export LC_ALL=C

@ -0,0 +1 @@
# u2/u2pp Streaming Test

@ -0,0 +1,22 @@
#!/bin/bash
set +x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=$data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
ctc_prefix_beam_search_decoder_main \
--model_path=$model_dir/export.jit \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--vocab_path=$model_dir/unit.txt \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--result_wspecifier=ark,t:$exp/result.ark
echo "u2 ctc prefix beam search decode."

@ -0,0 +1,27 @@
#!/bin/bash
set -x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
compute_fbank_main \
--num_bins 80 \
--wav_rspecifier=scp:$data/wav.scp \
--cmvn_file=$exp/cmvn.ark \
--feature_wspecifier=ark,t:$exp/fbank.ark
echo "compute fbank feature."

@ -0,0 +1,23 @@
#!/bin/bash
set -x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
u2_nnet_main \
--model_path=$model_dir/export.jit \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--acoustic_scale=1.0 \
--nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
--nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
echo "u2 nnet decode."

@ -0,0 +1,22 @@
#!/bin/bash
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
u2_recognizer_main \
--use_fbank=true \
--num_bins=80 \
--cmvn_file=$exp/cmvn.ark \
--model_path=$model_dir/export.jit \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--vocab_path=$model_dir/unit.txt \
--wav_rspecifier=scp:$data/wav.scp \
--result_wspecifier=ark,t:$exp/result.ark

@ -0,0 +1,18 @@
# This contains the locations of the built binaries required for running the examples.
unset GREP_OPTIONS
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_ALL=C
export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer
PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);")
export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH

@ -0,0 +1,43 @@
#!/bin/bash
set -x
set -e
. path.sh
# 1. compile
if [ ! -d ${SPEECHX_BUILD} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
# 2. download model
if [ ! -f data/model/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
mkdir -p data/model
pushd data/model
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
popd
fi
# produce wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p data
pushd data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
./local/feat.sh
./local/nnet.sh
./local/decode.sh

@ -1,5 +1,5 @@
#!/bin/bash #!/bin/bash
set +x set -x
set -e set -e
. path.sh . path.sh
@ -11,7 +11,7 @@ stop_stage=100
. utils/parse_options.sh . utils/parse_options.sh
# 1. compile # 1. compile
if [ ! -d ${SPEECHX_EXAMPLES} ]; then if [ ! -d ${SPEECHX_BUILD} ]; then
pushd ${SPEECHX_ROOT} pushd ${SPEECHX_ROOT}
bash build.sh bash build.sh
popd popd
@ -84,7 +84,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# recognizer # recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wolm.log \
ctc_prefix_beam_search_decoder_main \ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
@ -103,7 +103,7 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with lm # decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.lm.log \
ctc_prefix_beam_search_decoder_main \ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \
@ -135,7 +135,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder # TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.wfst.log \
tlg_decoder_main \ ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/feat.scp \
--model_path=$model_dir/avg_1.jit.pdmodel \ --model_path=$model_dir/avg_1.jit.pdmodel \
--param_path=$model_dir/avg_1.jit.pdiparams \ --param_path=$model_dir/avg_1.jit.pdiparams \

@ -84,7 +84,7 @@ fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# recognizer # recognizer
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wolm.log \
ctc_prefix_beam_search_decoder_main \ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \
@ -102,7 +102,7 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# decode with lm # decode with lm
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.lm.log \
ctc_prefix_beam_search_decoder_main \ ctc_beam_search_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \
@ -133,7 +133,7 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# TLG decoder # TLG decoder
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \ utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recog.fbank.wfst.log \
tlg_decoder_main \ ctc_tlg_decoder_main \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \ --feature_rspecifier=scp:$data/split${nj}/JOB/fbank_feat.scp \
--model_path=$model_dir/avg_5.jit.pdmodel \ --model_path=$model_dir/avg_5.jit.pdmodel \
--param_path=$model_dir/avg_5.jit.pdiparams \ --param_path=$model_dir/avg_5.jit.pdiparams \

@ -0,0 +1,5 @@
# U2/U2++ Streaming ASR
## Examples
* `wenetspeech` - Streaming Decoding with wenetspeech u2/u2++ model. Using aishell test data for testing.

@ -0,0 +1,28 @@
# u2/u2pp Streaming ASR
## Testing with Aishell Test Data
### Download wav and model
```
./run.sh --stop_stage 0
```
### compute feature
```
./run.sh --stage 1 --stop_stage 1
```
### decoding using feature
```
./run.sh --stage 2 --stop_stage 2
```
### decoding using wav
```
./run.sh --stage 3 --stop_stage 3
```
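
The `--stage` and `--stop_stage` options above select a contiguous range of steps. A minimal sketch of that gating follows, using the same `-le`/`-ge` test that appears in `run.sh`; `run_stages` is a hypothetical helper written for illustration, and the `echo` lines stand in for the real download, feature and decoding steps.

```shell
#!/usr/bin/env bash
# Hypothetical helper: shows which stages run for a given --stage/--stop_stage pair.
run_stages() {
  local stage=$1 stop_stage=$2
  local n
  for n in 0 1 2 3; do
    # A stage runs only when stage <= n <= stop_stage.
    if [ "${stage}" -le "${n}" ] && [ "${stop_stage}" -ge "${n}" ]; then
      echo "stage ${n}"
    fi
  done
}

run_stages 1 2
```

With `stage` left at `0`, passing `--stop_stage 0` runs only the download step, which matches the first command above.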

@ -0,0 +1,71 @@
#!/bin/bash
# To be run from one directory above this script.
. ./path.sh
nj=40
text=data/local/lm/text
lexicon=data/local/dict/lexicon.txt
for f in "$text" "$lexicon"; do
[ ! -f $f ] && echo "$0: No such file $f" && exit 1;
done
# Check SRILM tools
if ! which ngram-count > /dev/null; then
echo "srilm tools are not found, please download it and install it from: "
echo "http://www.speech.sri.com/projects/srilm/download.html"
echo "Then add the tools to your PATH"
exit 1
fi
# This script takes no arguments. It assumes you have already run
# aishell_data_prep.sh.
# It takes as input the files
# data/local/lm/text
# data/local/dict/lexicon.txt
dir=data/local/lm
mkdir -p $dir
cleantext=$dir/text.no_oov
# oov to <SPOKEN_NOISE>
# lexicon line: word char0 ... charn
# text line: utt word0 ... wordn -> line: <SPOKEN_NOISE> word0 ... wordn
text_dir=$(dirname $text)
split_name=$(basename $text)
./local/split_data.sh $text_dir $text $split_name $nj
utils/run.pl JOB=1:$nj $text_dir/split${nj}/JOB/${split_name}.no_oov.log \
cat ${text_dir}/split${nj}/JOB/${split_name} \| awk -v lex=$lexicon 'BEGIN{while((getline<lex) >0){ seen[$1]=1; } }
{for(n=1; n<=NF;n++) { if (seen[$n]) { printf("%s ", $n); } else {printf("<SPOKEN_NOISE> ");} } printf("\n");}' \
\> ${text_dir}/split${nj}/JOB/${split_name}.no_oov || exit 1;
cat ${text_dir}/split${nj}/*/${split_name}.no_oov > $cleantext
# compute word counts, sort in descending order
# line: count word
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | sort --parallel=`nproc` | uniq -c | \
sort --parallel=`nproc` -nr > $dir/word.counts || exit 1;
# Get counts from acoustic training transcripts, and add one-count
# for each word in the lexicon (but not silence, we don't want it
# in the LM-- we'll add it optionally later).
cat $cleantext | awk '{for(n=2;n<=NF;n++) print $n; }' | \
cat - <(grep -w -v '!SIL' $lexicon | awk '{print $1}') | \
sort --parallel=`nproc` | uniq -c | sort --parallel=`nproc` -nr > $dir/unigram.counts || exit 1;
# word with <s> </s>
cat $dir/unigram.counts | awk '{print $2}' | cat - <(echo "<s>"; echo "</s>" ) > $dir/wordlist
# hold out to compute ppl
heldout_sent=10000 # Don't change this if you want result to be comparable with kaldi_lm results
mkdir -p $dir
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
head -$heldout_sent > $dir/heldout
cat $cleantext | awk '{for(n=2;n<=NF;n++){ printf $n; if(n<NF) printf " "; else print ""; }}' | \
tail -n +$heldout_sent > $dir/train
ngram-count -text $dir/train -order 3 -limit-vocab -vocab $dir/wordlist -unk \
-map-unk "<UNK>" -kndiscount -interpolate -lm $dir/lm.arpa
ngram -lm $dir/lm.arpa -ppl $dir/heldout

@ -0,0 +1,25 @@
#!/bin/bash
set -e
. path.sh
data=data
exp=exp
nj=20
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/decoder.fbank.wolm.log \
ctc_prefix_beam_search_decoder_main \
--model_path=$model_dir/export.jit \
--vocab_path=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--feature_rspecifier=scp:$data/split${nj}/JOB/fbank.scp \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_decode.ark
cat $data/split${nj}/*/result_decode.ark > $exp/${label_file}
utils/compute-wer.py --char=1 --v=1 $text $exp/${label_file} > $exp/${wer}
tail -n 7 $exp/${wer}

@ -0,0 +1,31 @@
#!/bin/bash
set -e
. path.sh
data=data
exp=exp
nj=20
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
aishell_wav_scp=aishell_test.scp
cmvn_json2kaldi_main \
--json_file $model_dir/mean_std.json \
--cmvn_write_path $exp/cmvn.ark \
--binary=false
echo "convert json cmvn to kaldi ark."
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/feat.log \
compute_fbank_main \
--num_bins 80 \
--cmvn_file=$exp/cmvn.ark \
--streaming_chunk=36 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--feature_wspecifier=ark,scp:$data/split${nj}/JOB/fbank.ark,$data/split${nj}/JOB/fbank.scp
echo "compute fbank feature."

@ -0,0 +1,23 @@
#!/bin/bash
set -x
set -e
. path.sh
data=data
exp=exp
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
u2_nnet_main \
--model_path=$model_dir/export.jit \
--feature_rspecifier=ark,t:$exp/fbank.ark \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--acoustic_scale=1.0 \
--nnet_encoder_outs_wspecifier=ark,t:$exp/encoder_outs.ark \
--nnet_prob_wspecifier=ark,t:$exp/logprobs.ark
echo "u2 nnet decode."

@ -0,0 +1,37 @@
#!/bin/bash
set -e
. path.sh
data=data
exp=exp
nj=20
mkdir -p $exp
ckpt_dir=./data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
aishell_wav_scp=aishell_test.scp
text=$data/test/text
./local/split_data.sh $data $data/$aishell_wav_scp $aishell_wav_scp $nj
utils/run.pl JOB=1:$nj $data/split${nj}/JOB/recognizer.log \
u2_recognizer_main \
--use_fbank=true \
--num_bins=80 \
--cmvn_file=$exp/cmvn.ark \
--model_path=$model_dir/export.jit \
--vocab_path=$model_dir/unit.txt \
--nnet_decoder_chunk=16 \
--receptive_field_length=7 \
--subsampling_rate=4 \
--wav_rspecifier=scp:$data/split${nj}/JOB/${aishell_wav_scp} \
--result_wspecifier=ark,t:$data/split${nj}/JOB/result_recognizer.ark
cat $data/split${nj}/*/result_recognizer.ark > $exp/aishell_recognizer
utils/compute-wer.py --char=1 --v=1 $text $exp/aishell_recognizer > $exp/aishell.recognizer.err
echo "recognizer test have finished!!!"
echo "please checkout in $exp/aishell.recognizer.err"
tail -n 7 $exp/aishell.recognizer.err

@ -0,0 +1,30 @@
#!/usr/bin/env bash
set -eo pipefail
data=$1
scp=$2
split_name=$3
numsplit=$4
# save in $data/split{n}
# $scp to split
#
if [[ ! $numsplit -gt 0 ]]; then
echo "$0: Invalid num-split argument";
exit 1;
fi
directories=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n; done)
scp_splits=$(for n in `seq $numsplit`; do echo $data/split${numsplit}/$n/${split_name}; done)
# if this mkdir fails due to argument-list being too long, iterate.
if ! mkdir -p $directories >&/dev/null; then
for n in `seq $numsplit`; do
mkdir -p $data/split${numsplit}/$n
done
fi
echo "utils/split_scp.pl $scp $scp_splits"
utils/split_scp.pl $scp $scp_splits

@ -0,0 +1,18 @@
# This contains the locations of the built binaries required for running the examples.
unset GREP_OPTIONS
SPEECHX_ROOT=$PWD/../../../
SPEECHX_BUILD=$SPEECHX_ROOT/build/speechx
SPEECHX_TOOLS=$SPEECHX_ROOT/tools
TOOLS_BIN=$SPEECHX_TOOLS/valgrind/install/bin
[ -d $SPEECHX_BUILD ] || { echo "Error: 'build/speechx' directory not found. please ensure that the project build successfully"; }
export LC_ALL=C
export PATH=$PATH:$TOOLS_BIN:$SPEECHX_BUILD/nnet:$SPEECHX_BUILD/decoder:$SPEECHX_BUILD/frontend/audio:$SPEECHX_BUILD/recognizer
PADDLE_LIB_PATH=$(python -c "import os; import paddle; include_dir=paddle.sysconfig.get_include(); paddle_dir=os.path.split(include_dir)[0]; libs_dir=os.path.join(paddle_dir, 'libs'); fluid_dir=os.path.join(paddle_dir, 'fluid'); out=':'.join([libs_dir, fluid_dir]); print(out);")
export LD_LIBRARY_PATH=$PADDLE_LIB_PATH:$LD_LIBRARY_PATH

@ -0,0 +1,76 @@
#!/bin/bash
set +x
set -e
. path.sh
nj=40
stage=0
stop_stage=5
. utils/parse_options.sh
# input
data=data
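exp=exp
# scp file name used by stage 0 below; same value as in local/feat.sh and local/recognizer.sh
aishell_wav_scp=aishell_test.scp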
mkdir -p $exp $data
# 1. compile
if [ ! -d ${SPEECHX_BUILD} ]; then
pushd ${SPEECHX_ROOT}
bash build.sh
popd
fi
ckpt_dir=$data/model
model_dir=$ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model/
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ];then
# download model
if [ ! -f $ckpt_dir/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz ]; then
mkdir -p $ckpt_dir
pushd $ckpt_dir
wget -c https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/static/asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
tar xzfv asr1_chunk_conformer_u2pp_wenetspeech_static_1.1.0.model.tar.gz
popd
fi
# test wav scp
if [ ! -f data/wav.scp ]; then
mkdir -p $data
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav
echo "utt1 " $PWD/zh.wav > wav.scp
popd
fi
# aishell wav scp
if [ ! -d $data/test ]; then
pushd $data
wget -c https://paddlespeech.bj.bcebos.com/s2t/paddle_asr_online/aishell_test.zip
unzip aishell_test.zip
popd
realpath $data/test/*/*.wav > $data/wavlist
awk -F '/' '{ print $(NF) }' $data/wavlist | awk -F '.' '{ print $1 }' > $data/utt_id
paste $data/utt_id $data/wavlist > $data/$aishell_wav_scp
fi
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
./local/feat.sh
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
./local/decode.sh
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
./local/recognizer.sh
fi

@ -32,6 +32,12 @@ ${CMAKE_CURRENT_SOURCE_DIR}/decoder
) )
add_subdirectory(decoder) add_subdirectory(decoder)
include_directories(
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/recognizer
)
add_subdirectory(recognizer)
include_directories( include_directories(
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/protocol ${CMAKE_CURRENT_SOURCE_DIR}/protocol

@ -14,47 +14,47 @@
#pragma once #pragma once
#include "kaldi/base/kaldi-types.h"
#include <limits> #include <limits>
#include "kaldi/base/kaldi-types.h"
typedef float BaseFloat; typedef float BaseFloat;
typedef double double64; typedef double double64;
typedef signed char int8; typedef signed char int8;
typedef short int16; typedef short int16; // NOLINT
typedef int int32; typedef int int32; // NOLINT
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef long int64; typedef long int64; // NOLINT
#else #else
typedef long long int64; typedef long long int64; // NOLINT
#endif #endif
typedef unsigned char uint8; typedef unsigned char uint8; // NOLINT
typedef unsigned short uint16; typedef unsigned short uint16; // NOLINT
typedef unsigned int uint32; typedef unsigned int uint32; // NOLINT
#if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD) #if defined(__LP64__) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
typedef unsigned long uint64; typedef unsigned long uint64; // NOLINT
#else #else
typedef unsigned long long uint64; typedef unsigned long long uint64; // NOLINT
#endif #endif
typedef signed int char32; typedef signed int char32;
const uint8 kuint8max = ((uint8)0xFF); const uint8 kuint8max = static_cast<uint8>(0xFF);
const uint16 kuint16max = ((uint16)0xFFFF); const uint16 kuint16max = static_cast<uint16>(0xFFFF);
const uint32 kuint32max = ((uint32)0xFFFFFFFF); const uint32 kuint32max = static_cast<uint32>(0xFFFFFFFF);
const uint64 kuint64max = ((uint64)(0xFFFFFFFFFFFFFFFFLL)); const uint64 kuint64max = static_cast<uint64>(0xFFFFFFFFFFFFFFFFLL);
const int8 kint8min = ((int8)0x80); const int8 kint8min = static_cast<int8>(0x80);
const int8 kint8max = ((int8)0x7F); const int8 kint8max = static_cast<int8>(0x7F);
const int16 kint16min = ((int16)0x8000); const int16 kint16min = static_cast<int16>(0x8000);
const int16 kint16max = ((int16)0x7FFF); const int16 kint16max = static_cast<int16>(0x7FFF);
const int32 kint32min = ((int32)0x80000000); const int32 kint32min = static_cast<int32>(0x80000000);
const int32 kint32max = ((int32)0x7FFFFFFF); const int32 kint32max = static_cast<int32>(0x7FFFFFFF);
const int64 kint64min = ((int64)(0x8000000000000000LL)); const int64 kint64min = static_cast<int64>(0x8000000000000000LL);
const int64 kint64max = ((int64)(0x7FFFFFFFFFFFFFFFLL)); const int64 kint64max = static_cast<int64>(0x7FFFFFFFFFFFFFFFLL);
const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max(); const BaseFloat kBaseFloatMax = std::numeric_limits<BaseFloat>::max();
const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min(); const BaseFloat kBaseFloatMin = std::numeric_limits<BaseFloat>::min();
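
The `static_cast` constants above can be cross-checked against `std::numeric_limits`. The sketch below is a minimal illustration that assumes the `<cstdint>` fixed-width aliases stand in for the project's own typedefs; `CheckFloatLimits` is a hypothetical helper name. It also shows that `std::numeric_limits<float>::min()` is the smallest positive normal value, not the most negative one, which matters when reading `kBaseFloatMin`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Assumption: <cstdint> aliases replace the project's typedefs so that this
// sketch compiles on its own.
using uint8 = std::uint8_t;
using uint16 = std::uint16_t;
using uint32 = std::uint32_t;
using int16 = std::int16_t;
using int32 = std::int32_t;
using int64 = std::int64_t;

constexpr uint8 kuint8max = static_cast<uint8>(0xFF);
constexpr uint16 kuint16max = static_cast<uint16>(0xFFFF);
constexpr uint32 kuint32max = static_cast<uint32>(0xFFFFFFFF);
constexpr int16 kint16max = static_cast<int16>(0x7FFF);
constexpr int32 kint32max = static_cast<int32>(0x7FFFFFFF);
constexpr int64 kint64max = static_cast<int64>(0x7FFFFFFFFFFFFFFFLL);

// Compile-time checks: each cast constant equals the library-provided limit.
static_assert(kuint8max == std::numeric_limits<uint8>::max(), "uint8 max");
static_assert(kuint16max == std::numeric_limits<uint16>::max(), "uint16 max");
static_assert(kuint32max == std::numeric_limits<uint32>::max(), "uint32 max");
static_assert(kint16max == std::numeric_limits<int16>::max(), "int16 max");
static_assert(kint32max == std::numeric_limits<int32>::max(), "int32 max");
static_assert(kint64max == std::numeric_limits<int64>::max(), "int64 max");

// Hypothetical helper: true when float min() is positive and lowest() is the
// negation of max().
bool FloatMinIsPositive() {
    return std::numeric_limits<float>::min() > 0.0f &&
           std::numeric_limits<float>::lowest() ==
               -std::numeric_limits<float>::max();
}

// Runtime check that can be called from any test driver.
void CheckFloatLimits() { assert(FloatMinIsPositive()); }
```

This is presumably why code that needs a "minus infinity" style sentinel would negate the maximum rather than use the minimum constant.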

@ -14,21 +14,30 @@
#pragma once #pragma once
#include <algorithm>
#include <cassert>
#include <cmath>
#include <condition_variable> #include <condition_variable>
#include <cstring>
#include <deque> #include <deque>
#include <fstream> #include <fstream>
#include <iomanip>
#include <iostream> #include <iostream>
#include <istream> #include <istream>
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <numeric>
#include <ostream> #include <ostream>
#include <queue> #include <queue>
#include <set> #include <set>
#include <sstream> #include <sstream>
#include <stack> #include <stack>
#include <stdexcept>
#include <string> #include <string>
#include <thread> #include <thread>
#include <tuple>
#include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
@ -38,3 +47,5 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "base/macros.h" #include "base/macros.h"
#include "utils/file_utils.h"
#include "utils/math.h"

@ -14,6 +14,9 @@
#pragma once #pragma once
#include <limits>
#include <string>
namespace ppspeech { namespace ppspeech {
#ifndef DISALLOW_COPY_AND_ASSIGN #ifndef DISALLOW_COPY_AND_ASSIGN
@ -22,4 +25,7 @@ namespace ppspeech {
void operator=(const TypeName&) = delete void operator=(const TypeName&) = delete
#endif #endif
} // namespace pp_speech // kSpaceSymbol in UTF-8 is: ▁
const char kSpaceSymbo[] = "\xe2\x96\x81";
} // namespace ppspeech

@ -35,7 +35,7 @@
class ThreadPool { class ThreadPool {
public: public:
ThreadPool(size_t); explicit ThreadPool(size_t);
template <class F, class... Args> template <class F, class... Args>
auto enqueue(F&& f, Args&&... args) auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>; -> std::future<typename std::result_of<F(Args...)>::type>;
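
The `explicit` added to the `ThreadPool` constructor blocks implicit conversion from an integer. A minimal sketch of the difference follows; `ImplicitPool` and `ExplicitPool` are hypothetical stand-in classes, not the project's `ThreadPool`.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical class with a converting constructor (the pre-change shape).
class ImplicitPool {
  public:
    ImplicitPool(std::size_t n) : n_(n) {}  // NOLINT: deliberately implicit
    std::size_t size() const { return n_; }

  private:
    std::size_t n_;
};

// Hypothetical class with an explicit constructor (the post-change shape).
class ExplicitPool {
  public:
    explicit ExplicitPool(std::size_t n) : n_(n) {}
    std::size_t size() const { return n_; }

  private:
    std::size_t n_;
};

// An integer converts silently to ImplicitPool, but not to ExplicitPool.
static_assert(std::is_convertible<std::size_t, ImplicitPool>::value,
              "implicit conversion is allowed");
static_assert(!std::is_convertible<std::size_t, ExplicitPool>::value,
              "implicit conversion is rejected");
// Direct construction still works for the explicit version.
static_assert(std::is_constructible<ExplicitPool, std::size_t>::value,
              "direct construction is allowed");

// Builds a pool by direct initialization and reports its size.
std::size_t MakeAndMeasure(std::size_t n) {
    ExplicitPool pool(n);
    return pool.size();
}

void CheckExplicit() { assert(MakeAndMeasure(4) == 4); }
```

With the explicit form, a call such as `f(4)` where `f` takes a pool by value no longer compiles, which is the kind of accidental conversion cpplint flags.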

@ -17,7 +17,7 @@
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
// Initialize Google's logging library. // Initialize Google's logging library.
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1; FLAGS_logtostderr = 1;
LOG(INFO) << "Found " << 10 << " cookies"; LOG(INFO) << "Found " << 10 << " cookies";

@ -21,6 +21,7 @@
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <thread> #include <thread>
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "paddle_inference_api.h" #include "paddle_inference_api.h"
@ -63,8 +64,8 @@ void model_forward_test() {
; ;
std::string model_graph = FLAGS_model_path; std::string model_graph = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
CHECK(model_graph != ""); CHECK_NE(model_graph, "");
CHECK(model_params != ""); CHECK_NE(model_params, "");
cout << "model path: " << model_graph << endl; cout << "model path: " << model_graph << endl;
cout << "model param path : " << model_params << endl; cout << "model param path : " << model_params << endl;
@ -195,8 +196,11 @@ void model_forward_test() {
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
model_forward_test(); model_forward_test();
return 0; return 0;

@ -1,25 +1,55 @@
project(decoder)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ctc_decoders) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ctc_decoders)
add_library(decoder STATIC
ctc_beam_search_decoder.cc set(srcs)
if (USING_DS2)
list(APPEND srcs
ctc_decoders/decoder_utils.cpp ctc_decoders/decoder_utils.cpp
ctc_decoders/path_trie.cpp ctc_decoders/path_trie.cpp
ctc_decoders/scorer.cpp ctc_decoders/scorer.cpp
ctc_beam_search_decoder.cc
ctc_tlg_decoder.cc ctc_tlg_decoder.cc
recognizer.cc
) )
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder) endif()
set(BINS if (USING_U2)
ctc_prefix_beam_search_decoder_main list(APPEND srcs
ctc_prefix_beam_search_decoder.cc
)
endif()
add_library(decoder STATIC ${srcs})
target_link_libraries(decoder PUBLIC kenlm utils fst frontend nnet kaldi-decoder absl::strings)
# test
if (USING_DS2)
set(BINS
ctc_beam_search_decoder_main
nnet_logprob_decoder_main nnet_logprob_decoder_main
recognizer_main ctc_tlg_decoder_main
tlg_decoder_main )
)
foreach(bin_name IN LISTS BINS) foreach(bin_name IN LISTS BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS}) target_link_libraries(${bin_name} PUBLIC nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util ${DEPS})
endforeach() endforeach()
endif()
if (USING_U2)
set(TEST_BINS
ctc_prefix_beam_search_decoder_main
)
foreach(bin_name IN LISTS TEST_BINS)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util)
target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
endforeach()
endif()

@ -1,3 +1,4 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
@ -12,10 +13,36 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "base/basic_types.h" #pragma once
#include "base/common.h"
struct DecoderResult { struct DecoderResult {
BaseFloat acoustic_score; BaseFloat acoustic_score;
std::vector<int32> words_idx; std::vector<int32> words_idx;
std::vector<pair<int32, int32>> time_stamp; std::vector<std::pair<int32, int32>> time_stamp;
};
namespace ppspeech {
struct WordPiece {
std::string word;
int start = -1;
int end = -1;
WordPiece(std::string word, int start, int end)
: word(std::move(word)), start(start), end(end) {}
}; };
struct DecodeResult {
float score = -kBaseFloatMax;
std::string sentence;
std::vector<WordPiece> word_pieces;
static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
return a.score > b.score;
}
};
} // namespace ppspeech

@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "base/basic_types.h" #include "base/common.h"
#include "decoder/ctc_decoders/decoder_utils.h" #include "decoder/ctc_decoders/decoder_utils.h"
#include "utils/file_utils.h" #include "utils/file_utils.h"
@ -24,12 +25,7 @@ using std::vector;
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>; using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts) CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
: opts_(opts), : opts_(opts), init_ext_scorer_(nullptr), space_id_(-1), root_(nullptr) {
init_ext_scorer_(nullptr),
blank_id_(-1),
space_id_(-1),
num_frame_decoded_(0),
root_(nullptr) {
LOG(INFO) << "dict path: " << opts_.dict_file; LOG(INFO) << "dict path: " << opts_.dict_file;
if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) { if (!ReadFileToVector(opts_.dict_file, &vocabulary_)) {
LOG(INFO) << "load the dict failed"; LOG(INFO) << "load the dict failed";
@ -43,12 +39,12 @@ CTCBeamSearch::CTCBeamSearch(const CTCBeamSearchOptions& opts)
opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_); opts_.alpha, opts_.beta, opts_.lm_path, vocabulary_);
} }
blank_id_ = 0; CHECK_EQ(opts_.blank, 0);
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
auto it = std::find(vocabulary_.begin(), vocabulary_.end(), " ");
space_id_ = it - vocabulary_.begin(); space_id_ = it - vocabulary_.begin();
// if no space in vocabulary // if no space in vocabulary
if ((size_t)space_id_ >= vocabulary_.size()) { if (static_cast<size_t>(space_id_) >= vocabulary_.size()) {
space_id_ = -2; space_id_ = -2;
} }
} }
@ -84,8 +80,6 @@ void CTCBeamSearch::Decode(
return; return;
} }
int32 CTCBeamSearch::NumFrameDecoded() { return num_frame_decoded_ + 1; }
// todo rename, refactor // todo rename, refactor
void CTCBeamSearch::AdvanceDecode( void CTCBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) { const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
@ -110,17 +104,21 @@ void CTCBeamSearch::ResetPrefixes() {
} }
int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs, int CTCBeamSearch::DecodeLikelihoods(const vector<vector<float>>& probs,
vector<string>& nbest_words) { const vector<string>& nbest_words) {
kaldi::Timer timer; kaldi::Timer timer;
timer.Reset();
AdvanceDecoding(probs); AdvanceDecoding(probs);
LOG(INFO) << "ctc decoding elapsed time(s) " LOG(INFO) << "ctc decoding elapsed time(s) "
<< static_cast<float>(timer.Elapsed()) / 1000.0f; << static_cast<float>(timer.Elapsed()) / 1000.0f;
return 0; return 0;
} }
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath(int n) {
int beam_size = n == -1 ? opts_.beam_size : std::min(n, opts_.beam_size);
return get_beam_search_result(prefixes_, vocabulary_, beam_size);
}
vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() { vector<std::pair<double, string>> CTCBeamSearch::GetNBestPath() {
return get_beam_search_result(prefixes_, vocabulary_, opts_.beam_size); return GetNBestPath(-1);
} }
string CTCBeamSearch::GetBestPath() { string CTCBeamSearch::GetBestPath() {
@ -167,7 +165,7 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
continue; continue;
} }
min_cutoff = prefixes_[num_prefixes_ - 1]->score + min_cutoff = prefixes_[num_prefixes_ - 1]->score +
std::log(prob[blank_id_]) - std::log(prob[opts_.blank]) -
std::max(0.0, init_ext_scorer_->beta); std::max(0.0, init_ext_scorer_->beta);
full_beam = (num_prefixes_ == beam_size); full_beam = (num_prefixes_ == beam_size);
@ -195,9 +193,9 @@ void CTCBeamSearch::AdvanceDecoding(const vector<vector<BaseFloat>>& probs) {
for (size_t i = beam_size; i < prefixes_.size(); ++i) { for (size_t i = beam_size; i < prefixes_.size(); ++i) {
prefixes_[i]->remove(); prefixes_[i]->remove();
} }
} // if } // end if
num_frame_decoded_++; num_frame_decoded_++;
} // for probs_seq } // end for probs_seq
} }
int32 CTCBeamSearch::SearchOneChar( int32 CTCBeamSearch::SearchOneChar(
@ -215,7 +213,7 @@ int32 CTCBeamSearch::SearchOneChar(
break; break;
} }
if (c == blank_id_) { if (c == opts_.blank) {
prefix->log_prob_b_cur = prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue; continue;
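
The new `GetNBestPath(int n)` overload picks its beam size with a ternary: `-1` means "use the configured beam size", and any other value is capped at the configured size. The sketch below isolates that rule in a hypothetical free function, `EffectiveBeamSize`, so the behaviour can be checked without the decoder.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper mirroring the ternary in GetNBestPath(int n).
// n == -1 selects the configured beam size; otherwise n is capped at it.
int EffectiveBeamSize(int n, int configured_beam_size) {
    return n == -1 ? configured_beam_size
                   : std::min(n, configured_beam_size);
}

void CheckEffectiveBeamSize() {
    assert(EffectiveBeamSize(-1, 300) == 300);   // sentinel: use the default
    assert(EffectiveBeamSize(5, 300) == 5);      // smaller request is honoured
    assert(EffectiveBeamSize(500, 300) == 300);  // larger request is capped
}
```

Note that under this rule only `-1` is treated as a sentinel; other negative values pass through `std::min` unchanged, so callers would need to validate `n` themselves.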

@ -12,67 +12,47 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "base/common.h" // used by deepspeech2
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "kaldi/decoder/decodable-itf.h"
#include "util/parse-options.h"
#pragma once #pragma once
namespace ppspeech { #include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_decoders/path_trie.h"
#include "decoder/ctc_decoders/scorer.h"
#include "decoder/decoder_itf.h"
struct CTCBeamSearchOptions { namespace ppspeech {
std::string dict_file;
std::string lm_path;
BaseFloat alpha;
BaseFloat beta;
BaseFloat cutoff_prob;
int beam_size;
int cutoff_top_n;
int num_proc_bsearch;
CTCBeamSearchOptions()
: dict_file("vocab.txt"),
lm_path(""),
alpha(1.9f),
beta(5.0),
beam_size(300),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(10) {}
void Register(kaldi::OptionsItf* opts) {
opts->Register("dict", &dict_file, "dict file ");
opts->Register("lm-path", &lm_path, "language model file");
opts->Register("alpha", &alpha, "alpha");
opts->Register("beta", &beta, "beta");
opts->Register(
"beam-size", &beam_size, "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, "cutoff top n");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, "num proc bsearch");
}
};
class CTCBeamSearch { class CTCBeamSearch : public DecoderBase {
public: public:
explicit CTCBeamSearch(const CTCBeamSearchOptions& opts); explicit CTCBeamSearch(const CTCBeamSearchOptions& opts);
~CTCBeamSearch() {} ~CTCBeamSearch() {}
void InitDecoder(); void InitDecoder();
void Reset();
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable); void Decode(std::shared_ptr<kaldi::DecodableInterface> decodable);
std::string GetBestPath(); std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath(); std::vector<std::pair<double, std::string>> GetNBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath(int n);
std::string GetFinalBestPath(); std::string GetFinalBestPath();
int NumFrameDecoded();
std::string GetPartialResult() {
CHECK(false) << "Not implemented.";
return {};
}
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs, int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words); const std::vector<std::string>& nbest_words);
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
private: private:
void ResetPrefixes(); void ResetPrefixes();
int32 SearchOneChar(const bool& full_beam, int32 SearchOneChar(const bool& full_beam,
const std::pair<size_t, BaseFloat>& log_prob_idx, const std::pair<size_t, BaseFloat>& log_prob_idx,
const BaseFloat& min_cutoff); const BaseFloat& min_cutoff);
@ -83,12 +63,11 @@ class CTCBeamSearch {
CTCBeamSearchOptions opts_; CTCBeamSearchOptions opts_;
std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later std::shared_ptr<Scorer> init_ext_scorer_; // todo separate later
std::vector<std::string> vocabulary_; // todo remove later std::vector<std::string> vocabulary_; // todo remove later
size_t blank_id_;
int space_id_; int space_id_;
std::shared_ptr<PathTrie> root_; std::shared_ptr<PathTrie> root_;
std::vector<PathTrie*> prefixes_; std::vector<PathTrie*> prefixes_;
int num_frame_decoded_;
DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch); DISALLOW_COPY_AND_ASSIGN(CTCBeamSearch);
}; };
} // namespace basr } // namespace ppspeech

@ -12,29 +12,26 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// todo refactor, replace with gtest // used by deepspeech2
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "decoder/ctc_tlg_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/paddle_nnet.h" #include "nnet/ds2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table"); DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm");
DEFINE_string(graph_path, "TLG", "decoder graph"); DEFINE_string(lm_path, "", "language model");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_int32(max_active, 7500, "decoder graph");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=3) downsampling module."); "receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate, DEFINE_int32(subsampling_rate,
4, 4,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
DEFINE_string( DEFINE_string(
@ -48,59 +45,59 @@ DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box", "chunk_state_h_box,chunk_state_c_box",
"model cache names"); "model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
using std::vector; using std::vector;
// test TLG decoder by feeding speech feature. // test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK_NE(FLAGS_result_wspecifier, "");
CHECK_NE(FLAGS_feature_rspecifier, "");
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_graph = FLAGS_model_path; std::string model_path = FLAGS_model_path;
std::string model_params = FLAGS_param_path; std::string model_params = FLAGS_param_path;
std::string word_symbol_table = FLAGS_word_symbol_table; std::string dict_file = FLAGS_dict_file;
std::string graph_path = FLAGS_graph_path; std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_graph; LOG(INFO) << "model path: " << model_path;
LOG(INFO) << "model param: " << model_params; LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "word symbol path: " << word_symbol_table; LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "graph path: " << graph_path; LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts; ppspeech::CTCBeamSearchOptions opts;
opts.word_symbol_table = word_symbol_table; opts.dict_file = dict_file;
opts.fst_path = graph_path; opts.lm_path = lm_path;
opts.opts.max_active = FLAGS_max_active; ppspeech::CTCBeamSearch decoder(opts);
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5; ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_graph;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet( std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts)); new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache()); std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable( std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale)); new ppspeech::Decodable(nnet, raw_data));
int32 chunk_size = FLAGS_receptive_field_length + int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length; LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder(); decoder.InitDecoder();
kaldi::Timer timer; kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) { for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key(); string utt = feature_reader.Key();
@ -132,6 +129,7 @@ int main(int argc, char* argv[]) {
if (feature_chunk_size < receptive_field_length) break; if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride; int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) { for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start); kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp( kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
@ -161,10 +159,9 @@ int main(int argc, char* argv[]) {
++num_done; ++num_done;
} }
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "Done " << num_done << " utterances, " << num_err KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors."; << " with errors.";
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
return (num_done != 0 ? 0 : 1); return (num_done != 0 ? 0 : 1);
} }
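
The chunking arithmetic in `main` above can be worked through in isolation. The sketch below uses a hypothetical `ChunkLayout` struct and `ComputeChunkLayout` function; only the two formulas are taken from the diff.

```cpp
#include <cassert>

// Hypothetical container for the two values computed in main().
struct ChunkLayout {
    int size;    // frames fed to the network per chunk
    int stride;  // frames to advance between chunks
};

// size   = receptive_field_length + (nnet_decoder_chunk - 1) * subsampling_rate
// stride = subsampling_rate * nnet_decoder_chunk
ChunkLayout ComputeChunkLayout(int receptive_field_length,
                               int subsampling_rate,
                               int nnet_decoder_chunk) {
    ChunkLayout layout;
    layout.size = receptive_field_length +
                  (nnet_decoder_chunk - 1) * subsampling_rate;
    layout.stride = subsampling_rate * nnet_decoder_chunk;
    return layout;
}

void CheckChunkLayout() {
    // Flag defaults from the diff: receptive field 7, rate 4, chunk 1.
    ChunkLayout d = ComputeChunkLayout(7, 4, 1);
    assert(d.size == 7 && d.stride == 4);
    // Two decoder frames per chunk: 7 + 1 * 4 = 11 frames, stride 8.
    ChunkLayout two = ComputeChunkLayout(7, 4, 2);
    assert(two.size == 11 && two.stride == 8);
}
```

Because the size exceeds the stride by `receptive_field_length - subsampling_rate` frames, consecutive chunks overlap by that amount (3 frames with the defaults).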

@ -0,0 +1,78 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "util/parse-options.h"
namespace ppspeech {
struct CTCBeamSearchOptions {
// common
int blank;
// ds2
std::string dict_file;
std::string lm_path;
int beam_size;
BaseFloat alpha;
BaseFloat beta;
BaseFloat cutoff_prob;
int cutoff_top_n;
int num_proc_bsearch;
// u2
int first_beam_size;
int second_beam_size;
CTCBeamSearchOptions()
: blank(0),
dict_file("vocab.txt"),
lm_path(""),
beam_size(300),
alpha(1.9f),
beta(5.0),
cutoff_prob(0.99f),
cutoff_top_n(40),
num_proc_bsearch(10),
first_beam_size(10),
second_beam_size(10) {}
void Register(kaldi::OptionsItf* opts) {
std::string module = "Ds2BeamSearchConfig: ";
opts->Register("dict", &dict_file, module + "vocab file path.");
opts->Register(
"lm-path", &lm_path, module + "ngram language model path.");
opts->Register("alpha", &alpha, module + "alpha");
opts->Register("beta", &beta, module + "beta");
opts->Register("beam-size",
&beam_size,
module + "beam size for beam search method");
opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs");
opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n");
opts->Register(
"num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch");
opts->Register("blank", &blank, "blank id, default is 0.");
module = "U2BeamSearchConfig: ";
opts->Register(
"first-beam-size", &first_beam_size, module + "first beam size.");
opts->Register("second-beam-size",
&second_beam_size,
module + "second beam size.");
}
};
} // namespace ppspeech
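
The `second-beam-size` option registered above caps how many hypotheses survive each frame in the prefix beam search. A minimal sketch of that pruning pattern (`std::nth_element`, then `resize`, then `std::sort`) is given below. `Hyp`, `HypCompare`, and `KeepTopN` are hypothetical names, and a plain `float` replaces the decoder's score type.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical simplification: (token prefix, total log score).
using Hyp = std::pair<std::vector<int>, float>;

// Higher log score sorts first.
static bool HypCompare(const Hyp& a, const Hyp& b) {
    return a.second > b.second;
}

// Keeps the n best hypotheses, best first. Assumes n >= 0.
std::vector<Hyp> KeepTopN(std::vector<Hyp> hyps, int n) {
    const int keep = std::min(static_cast<int>(hyps.size()), n);
    // Partition so the `keep` best entries occupy the front, in any order.
    std::nth_element(hyps.begin(), hyps.begin() + keep, hyps.end(),
                     HypCompare);
    hyps.resize(keep);
    // Order only the survivors.
    std::sort(hyps.begin(), hyps.end(), HypCompare);
    return hyps;
}

void CheckKeepTopN() {
    std::vector<Hyp> in;
    in.emplace_back(std::vector<int>{1}, -3.0f);
    in.emplace_back(std::vector<int>{2}, -1.0f);
    in.emplace_back(std::vector<int>{3}, -2.0f);
    std::vector<Hyp> out = KeepTopN(in, 2);
    assert(out.size() == 2);
    assert(out[0].second == -1.0f);
    assert(out[1].second == -2.0f);
}
```

Partitioning first and sorting only the survivors avoids a full sort of every candidate, which helps when the candidate set is much larger than the beam.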

@ -0,0 +1,370 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_prefix_beam_search_decoder.h"
#include "absl/strings/str_join.h"
#include "base/common.h"
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "utils/math.h"
#ifdef USE_PROFILING
#include "paddle/fluid/platform/profiler.h"
using paddle::platform::RecordEvent;
using paddle::platform::TracerEventType;
#endif
namespace ppspeech {
CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string& vocab_path,
const CTCBeamSearchOptions& opts)
: opts_(opts) {
unit_table_ = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(vocab_path));
CHECK(unit_table_ != nullptr);
Reset();
}
void CTCPrefixBeamSearch::Reset() {
num_frame_decoded_ = 0;
cur_hyps_.clear();
hypotheses_.clear();
likelihood_.clear();
viterbi_likelihood_.clear();
times_.clear();
outputs_.clear();
// empty hyp with Score
std::vector<int> empty;
PrefixScore prefix_score;
prefix_score.InitEmpty();
cur_hyps_[empty] = prefix_score;
outputs_.emplace_back(empty);
hypotheses_.emplace_back(empty);
likelihood_.emplace_back(prefix_score.TotalScore());
times_.emplace_back(empty);
}
void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
void CTCPrefixBeamSearch::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (1) {
// forward frame by frame
std::vector<kaldi::BaseFloat> frame_prob;
bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
if (flag == false) {
VLOG(1) << "decoder advance decode exit." << frame_prob.size();
break;
}
std::vector<std::vector<kaldi::BaseFloat>> likelihood;
likelihood.push_back(frame_prob);
AdvanceDecoding(likelihood);
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
}
}
static bool PrefixScoreCompare(
const std::pair<std::vector<int>, PrefixScore>& a,
const std::pair<std::vector<int>, PrefixScore>& b) {
// log domain
return a.second.TotalScore() > b.second.TotalScore();
}
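
The comparator above ranks hypotheses in the log domain, and the token-passing step that follows merges paths into the same prefix with `LogSumExp`. A minimal sketch of a numerically stable two-argument log-sum-exp is shown here. `LogSumExpSketch` is a hypothetical name, and the project's own helper (presumably declared in `utils/math.h`) may be implemented differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical helper: returns log(exp(x) + exp(y)) without overflow.
// -inf acts as the identity element (log of probability zero).
float LogSumExpSketch(float x, float y) {
    const float kNegInf = -std::numeric_limits<float>::infinity();
    if (x == kNegInf) return y;
    if (y == kNegInf) return x;
    // Factor out the larger term so the exponentials stay in (0, 1].
    const float m = std::max(x, y);
    return m + std::log(std::exp(x - m) + std::exp(y - m));
}

void CheckLogSumExp() {
    const float kNegInf = -std::numeric_limits<float>::infinity();
    // 0.25 + 0.25 = 0.5 in the probability domain.
    const float merged = LogSumExpSketch(std::log(0.25f), std::log(0.25f));
    assert(std::fabs(merged - std::log(0.5f)) < 1e-5f);
    // Merging with an impossible path leaves the score unchanged.
    assert(LogSumExpSketch(kNegInf, -1.0f) == -1.0f);
}
```

Treating `-inf` as the identity matches the comment further down about a freshly inserted score starting at `-inf`: the first merge simply adopts the incoming path score.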
void CTCPrefixBeamSearch::AdvanceDecoding(
const std::vector<std::vector<kaldi::BaseFloat>>& logp) {
#ifdef USE_PROFILING
RecordEvent event("CtcPrefixBeamSearch::AdvanceDecoding",
TracerEventType::UserDefined,
1);
#endif
if (logp.size() == 0) return;
int first_beam_size =
std::min(static_cast<int>(logp[0].size()), opts_.first_beam_size);
for (int t = 0; t < logp.size(); ++t, ++num_frame_decoded_) {
const std::vector<kaldi::BaseFloat>& logp_t = logp[t];
std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash>
next_hyps;
// 1. first beam prune, only select topk candidates
std::vector<kaldi::BaseFloat> topk_score;
std::vector<int32_t> topk_index;
TopK(logp_t, first_beam_size, &topk_score, &topk_index);
VLOG(2) << "topk: " << num_frame_decoded_ << " "
<< *std::max_element(logp_t.begin(), logp_t.end()) << " "
<< topk_score[0];
for (int i = 0; i < topk_score.size(); i++) {
VLOG(2) << "topk: " << num_frame_decoded_ << " " << topk_score[i];
}
// 2. token passing
for (int i = 0; i < topk_index.size(); ++i) {
int id = topk_index[i];
auto prob = topk_score[i];
for (const auto& it : cur_hyps_) {
const std::vector<int>& prefix = it.first;
const PrefixScore& prefix_score = it.second;
// If prefix doesn't exist in next_hyps, next_hyps[prefix] inserts
// PrefixScore(-inf, -inf) by default, since the default constructor
// of PrefixScore sets fields b (blank-ending score) and
// nb (non-blank-ending score) to -inf, respectively.
if (id == opts_.blank) {
// case 0: *a + <blank> => *a, *a<blank> + <blank> => *a,
// prefix not changed
PrefixScore& next_score = next_hyps[prefix];
next_score.b =
LogSumExp(next_score.b, prefix_score.Score() + prob);
// timestamp: blank is silence and does not affect the timestamp
next_score.v_b = prefix_score.ViterbiScore() + prob;
next_score.times_b = prefix_score.Times();
// Prefix not changed, copy the context from prefix
if (context_graph_ && !next_score.has_context) {
next_score.CopyContext(prefix_score);
next_score.has_context = true;
}
} else if (!prefix.empty() && id == prefix.back()) {
// case 1: *a + a => *a, prefix not changed
PrefixScore& next_score1 = next_hyps[prefix];
next_score1.nb =
LogSumExp(next_score1.nb, prefix_score.nb + prob);
// timestamp: a non-blank symbol affects the timestamp
if (next_score1.v_nb < prefix_score.v_nb + prob) {
// compute viterbi Score
next_score1.v_nb = prefix_score.v_nb + prob;
if (next_score1.cur_token_prob < prob) {
// store max token prob
next_score1.cur_token_prob = prob;
// update this timestamp as token appeared here.
next_score1.times_nb = prefix_score.times_nb;
assert(next_score1.times_nb.size() > 0);
next_score1.times_nb.back() = num_frame_decoded_;
}
}
// Prefix not changed, copy the context from prefix
if (context_graph_ && !next_score1.has_context) {
next_score1.CopyContext(prefix_score);
next_score1.has_context = true;
}
// case 2: *a<blank> + a => *aa, prefix changed.
std::vector<int> new_prefix(prefix);
new_prefix.emplace_back(id);
PrefixScore& next_score2 = next_hyps[new_prefix];
next_score2.nb =
LogSumExp(next_score2.nb, prefix_score.b + prob);
// timestamp: a non-blank symbol affects the timestamp
if (next_score2.v_nb < prefix_score.v_b + prob) {
// compute viterbi Score
next_score2.v_nb = prefix_score.v_b + prob;
// new token added
next_score2.cur_token_prob = prob;
next_score2.times_nb = prefix_score.times_b;
next_score2.times_nb.emplace_back(num_frame_decoded_);
}
// Prefix changed, calculate the context Score.
if (context_graph_ && !next_score2.has_context) {
next_score2.UpdateContext(
context_graph_, prefix_score, id, prefix.size());
next_score2.has_context = true;
}
} else {
// id != prefix.back()
// case 3: *a + b => *ab, *a<blank> +b => *ab
std::vector<int> new_prefix(prefix);
new_prefix.emplace_back(id);
PrefixScore& next_score = next_hyps[new_prefix];
next_score.nb =
LogSumExp(next_score.nb, prefix_score.Score() + prob);
// timestamp: a non-blank symbol affects the timestamp
if (next_score.v_nb < prefix_score.ViterbiScore() + prob) {
next_score.v_nb = prefix_score.ViterbiScore() + prob;
next_score.cur_token_prob = prob;
next_score.times_nb = prefix_score.Times();
next_score.times_nb.emplace_back(num_frame_decoded_);
}
// Prefix changed, calculate the context Score.
if (context_graph_ && !next_score.has_context) {
next_score.UpdateContext(
context_graph_, prefix_score, id, prefix.size());
next_score.has_context = true;
}
}
} // end for (const auto& it : cur_hyps_)
} // end for (int i = 0; i < topk_index.size(); ++i)
// 3. second beam prune, only keep top n best paths
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(
next_hyps.begin(), next_hyps.end());
int second_beam_size =
std::min(static_cast<int>(arr.size()), opts_.second_beam_size);
std::nth_element(arr.begin(),
arr.begin() + second_beam_size,
arr.end(),
PrefixScoreCompare);
arr.resize(second_beam_size);
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
// 4. update cur_hyps by next_hyps, and get new result
UpdateHypotheses(arr);
} // end for (int t = 0; t < logp.size(); ++t, ++num_frame_decoded_)
}
void CTCPrefixBeamSearch::UpdateHypotheses(
const std::vector<std::pair<std::vector<int>, PrefixScore>>& hyps) {
cur_hyps_.clear();
outputs_.clear();
hypotheses_.clear();
likelihood_.clear();
viterbi_likelihood_.clear();
times_.clear();
for (auto& item : hyps) {
cur_hyps_[item.first] = item.second;
UpdateOutputs(item);
hypotheses_.emplace_back(std::move(item.first));
likelihood_.emplace_back(item.second.TotalScore());
viterbi_likelihood_.emplace_back(item.second.ViterbiScore());
times_.emplace_back(item.second.Times());
}
}
void CTCPrefixBeamSearch::UpdateOutputs(
const std::pair<std::vector<int>, PrefixScore>& prefix) {
const std::vector<int>& input = prefix.first;
const std::vector<int>& start_boundaries = prefix.second.start_boundaries;
const std::vector<int>& end_boundaries = prefix.second.end_boundaries;
// add <context> </context> tag
std::vector<int> output;
int s = 0;
int e = 0;
for (int i = 0; i < input.size(); ++i) {
output.emplace_back(input[i]);
}
outputs_.emplace_back(output);
}
void CTCPrefixBeamSearch::FinalizeSearch() {
UpdateFinalContext();
VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
for (int i = 0; i < hypotheses_.size(); i++) {
VLOG(2) << "hyp " << i << " len: " << hypotheses_[i].size()
<< " ctc score: " << likelihood_[i];
for (int j = 0; j < hypotheses_[i].size(); j++) {
VLOG(2) << hypotheses_[i][j];
}
}
}
void CTCPrefixBeamSearch::UpdateFinalContext() {
if (context_graph_ == nullptr) return;
CHECK(hypotheses_.size() == cur_hyps_.size());
CHECK(hypotheses_.size() == likelihood_.size());
// We should backoff the context Score/state when the context is
// not fully matched at the last time.
for (const auto& prefix : hypotheses_) {
PrefixScore& prefix_score = cur_hyps_[prefix];
if (prefix_score.context_score != 0) {
prefix_score.UpdateContext(
context_graph_, prefix_score, 0, prefix.size());
}
}
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(cur_hyps_.begin(),
cur_hyps_.end());
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
// Update cur_hyps_ and get new result
UpdateHypotheses(arr);
}
std::string CTCPrefixBeamSearch::GetBestPath(int index) {
int n_hyps = Outputs().size();
CHECK_GT(n_hyps, 0);
CHECK_LT(index, n_hyps);
std::vector<int> one = Outputs()[index];
std::string sentence;
for (int i = 0; i < one.size(); i++) {
sentence += unit_table_->Find(one[i]);
}
return sentence;
}
std::string CTCPrefixBeamSearch::GetBestPath() { return GetBestPath(0); }
std::vector<std::pair<double, std::string>> CTCPrefixBeamSearch::GetNBestPath(
int n) {
int hyps_size = hypotheses_.size();
CHECK_GT(hyps_size, 0);
int min_n = n == -1 ? hypotheses_.size() : std::min(n, hyps_size);
std::vector<std::pair<double, std::string>> n_best;
n_best.reserve(min_n);
for (int i = 0; i < min_n; i++) {
n_best.emplace_back(Likelihood()[i], GetBestPath(i));
}
return n_best;
}
std::vector<std::pair<double, std::string>>
CTCPrefixBeamSearch::GetNBestPath() {
return GetNBestPath(-1);
}
std::string CTCPrefixBeamSearch::GetFinalBestPath() { return GetBestPath(); }
std::string CTCPrefixBeamSearch::GetPartialResult() { return GetBestPath(); }
} // namespace ppspeech
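
The token-passing step above (cases 0 to 3) can be sketched in isolation. This is a minimal illustration, not the code from this commit: `ToyScore`, `ToyLogSumExp`, `ToyStep`, and `Hyps` are hypothetical names, a `std::map` stands in for the `unordered_map` keyed by `PrefixScoreHash`, and the Viterbi scores, timestamps, context biasing, and both beam prunes are left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <map>
#include <vector>

// Numerically stable log(exp(x) + exp(y)); -inf means probability 0.
inline float ToyLogSumExp(float x, float y) {
    const float ninf = -std::numeric_limits<float>::infinity();
    if (x == ninf) return y;
    if (y == ninf) return x;
    const float m = std::max(x, y);
    return m + std::log(std::exp(x - m) + std::exp(y - m));
}

// Hypothetical stand-in for PrefixScore: only b and nb are kept.
struct ToyScore {
    float b = -std::numeric_limits<float>::infinity();   // ends in blank
    float nb = -std::numeric_limits<float>::infinity();  // ends in non-blank
    float Score() const { return ToyLogSumExp(b, nb); }
};

using Hyps = std::map<std::vector<int>, ToyScore>;

// One frame of token passing over every (token, prefix) pair, no pruning.
inline Hyps ToyStep(const Hyps& cur, const std::vector<float>& logp, int blank) {
    assert(blank >= 0 && blank < static_cast<int>(logp.size()));
    Hyps next;
    for (int id = 0; id < static_cast<int>(logp.size()); ++id) {
        const float prob = logp[id];
        for (const auto& it : cur) {
            const std::vector<int>& prefix = it.first;
            const ToyScore& s = it.second;
            if (id == blank) {
                // case 0: prefix unchanged, now ends in blank
                ToyScore& n = next[prefix];
                n.b = ToyLogSumExp(n.b, s.Score() + prob);
            } else if (!prefix.empty() && id == prefix.back()) {
                // case 1: *a + a => *a (repeat merges into the same prefix)
                ToyScore& n1 = next[prefix];
                n1.nb = ToyLogSumExp(n1.nb, s.nb + prob);
                // case 2: *a<blank> + a => *aa (blank separates the repeat)
                std::vector<int> new_prefix(prefix);
                new_prefix.push_back(id);
                ToyScore& n2 = next[new_prefix];
                n2.nb = ToyLogSumExp(n2.nb, s.b + prob);
            } else {
                // case 3: a different token extends the prefix
                std::vector<int> new_prefix(prefix);
                new_prefix.push_back(id);
                ToyScore& n = next[new_prefix];
                n.nb = ToyLogSumExp(n.nb, s.Score() + prob);
            }
        }
    }
    return next;
}
```

Starting from an empty prefix with `b = 0` (as `InitEmpty()` does) and feeding two frames with a two-symbol vocabulary at probability 0.5 each, the empty prefix should end with probability 0.25 and the prefix `[1]` with 0.75, which matches enumerating the four CTC paths by hand.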

@ -0,0 +1,101 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/ctc_prefix_beam_search.cc
#pragma once
#include "decoder/ctc_beam_search_opt.h"
#include "decoder/ctc_prefix_beam_search_score.h"
#include "decoder/decoder_itf.h"
#include "fst/symbol-table.h"
namespace ppspeech {
class ContextGraph;
class CTCPrefixBeamSearch : public DecoderBase {
public:
CTCPrefixBeamSearch(const std::string& vocab_path,
const CTCBeamSearchOptions& opts);
~CTCPrefixBeamSearch() {}
SearchType Type() const { return SearchType::kPrefixBeamSearch; }
void InitDecoder() override;
void Reset() override;
void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) override;
std::string GetFinalBestPath() override;
std::string GetPartialResult() override;
void FinalizeSearch();
const std::shared_ptr<fst::SymbolTable> VocabTable() const {
return unit_table_;
}
const std::vector<std::vector<int>>& Inputs() const { return hypotheses_; }
const std::vector<std::vector<int>>& Outputs() const { return outputs_; }
const std::vector<float>& Likelihood() const { return likelihood_; }
const std::vector<float>& ViterbiLikelihood() const {
return viterbi_likelihood_;
}
const std::vector<std::vector<int>>& Times() const { return times_; }
protected:
std::string GetBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath(int n) override;
private:
std::string GetBestPath(int index);
void AdvanceDecoding(
const std::vector<std::vector<kaldi::BaseFloat>>& logp);
void UpdateOutputs(const std::pair<std::vector<int>, PrefixScore>& prefix);
void UpdateHypotheses(
const std::vector<std::pair<std::vector<int>, PrefixScore>>& prefix);
void UpdateFinalContext();
private:
CTCBeamSearchOptions opts_;
std::shared_ptr<fst::SymbolTable> unit_table_{nullptr};
std::unordered_map<std::vector<int>, PrefixScore, PrefixScoreHash>
cur_hyps_;
// n-best list and corresponding likelihood, in sorted order
std::vector<std::vector<int>> hypotheses_;
std::vector<float> likelihood_;
std::vector<std::vector<int>> times_;
std::vector<float> viterbi_likelihood_;
// Outputs contain the hypotheses_ and tags like: <context> and </context>
std::vector<std::vector<int>> outputs_;
std::shared_ptr<ContextGraph> context_graph_{nullptr};
DISALLOW_COPY_AND_ASSIGN(CTCPrefixBeamSearch);
};
} // namespace ppspeech
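
The header keeps the n-best list "in sorted order"; in the .cc this comes from the second beam prune (`std::nth_element`, `resize`, then `std::sort`). A minimal sketch of that pattern follows. `ToyHyp`, `ToyHypCompare`, and `ToyPrune` are hypothetical names, and the comparator is assumed to order by descending score, since `PrefixScoreCompare` is not shown in this part of the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

using ToyHyp = std::pair<std::vector<int>, float>;  // (prefix, log score)

// Assumed ordering: higher score sorts first.
inline bool ToyHypCompare(const ToyHyp& a, const ToyHyp& b) {
    return a.second > b.second;
}

// Keep the beam_size best hypotheses, best first.
inline std::vector<ToyHyp> ToyPrune(std::vector<ToyHyp> arr, int beam_size) {
    assert(beam_size >= 0);
    const int n = std::min(static_cast<int>(arr.size()), beam_size);
    // Partition so the n best elements occupy the front (unordered).
    std::nth_element(arr.begin(), arr.begin() + n, arr.end(), ToyHypCompare);
    arr.resize(n);
    // Only the survivors need a full sort.
    std::sort(arr.begin(), arr.end(), ToyHypCompare);
    return arr;
}
```

Partitioning first and sorting only the survivors avoids sorting hypotheses that are about to be dropped, which matters when the candidate set is much larger than the beam.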

@ -12,40 +12,29 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// todo refactor, replace with gtest #include "absl/strings/str_split.h"
#include "base/common.h"
#include "base/flags.h" #include "decoder/ctc_prefix_beam_search_decoder.h"
#include "base/log.h"
#include "decoder/ctc_beam_search_decoder.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "fst/symbol-table.h"
#include "kaldi/util/table-types.h" #include "kaldi/util/table-types.h"
#include "nnet/decodable.h" #include "nnet/decodable.h"
#include "nnet/paddle_nnet.h" #include "nnet/u2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(vocab_path, "", "vocab path");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); DEFINE_string(model_path, "", "paddle nnet model");
DEFINE_string(lm_path, "", "language model");
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=3) downsampling module."); "receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate, DEFINE_int32(subsampling_rate,
4, 4,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
DEFINE_string(
model_input_names, DEFINE_int32(nnet_decoder_chunk, 16, "paddle nnet forward chunk");
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names,
"chunk_state_h_box,chunk_state_c_box",
"model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Matrix; using kaldi::Matrix;
@ -53,117 +42,138 @@ using std::vector;
// test ds2 online decoder by feeding speech feature // test ds2 online decoder by feeding speech feature
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK(FLAGS_result_wspecifier != ""); int32 num_done = 0, num_err = 0;
CHECK(FLAGS_feature_rspecifier != "");
CHECK_NE(FLAGS_result_wspecifier, "");
CHECK_NE(FLAGS_feature_rspecifier, "");
CHECK_NE(FLAGS_vocab_path, "");
CHECK_NE(FLAGS_model_path, "");
LOG(INFO) << "model path: " << FLAGS_model_path;
LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path;
kaldi::SequentialBaseFloatMatrixReader feature_reader( kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier); FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
std::string model_path = FLAGS_model_path;
std::string model_params = FLAGS_param_path;
std::string dict_file = FLAGS_dict_file;
std::string lm_path = FLAGS_lm_path;
LOG(INFO) << "model path: " << model_path;
LOG(INFO) << "model param: " << model_params;
LOG(INFO) << "dict path: " << dict_file;
LOG(INFO) << "lm path: " << lm_path;
int32 num_done = 0, num_err = 0; // nnet
ppspeech::ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
std::shared_ptr<ppspeech::U2Nnet> nnet =
std::make_shared<ppspeech::U2Nnet>(model_opts);
// decodable
std::shared_ptr<ppspeech::DataCache> raw_data =
std::make_shared<ppspeech::DataCache>();
std::shared_ptr<ppspeech::Decodable> decodable =
std::make_shared<ppspeech::Decodable>(nnet, raw_data);
// decoder
ppspeech::CTCBeamSearchOptions opts; ppspeech::CTCBeamSearchOptions opts;
opts.dict_file = dict_file; opts.blank = 0;
opts.lm_path = lm_path; opts.first_beam_size = 10;
ppspeech::CTCBeamSearch decoder(opts); opts.second_beam_size = 10;
ppspeech::CTCPrefixBeamSearch decoder(FLAGS_vocab_path, opts);
ppspeech::ModelOptions model_opts;
model_opts.model_path = model_path;
model_opts.param_path = model_params;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data));
int32 chunk_size = FLAGS_receptive_field_length + int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_downsampling_rate; (FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_downsampling_rate * FLAGS_nnet_decoder_chunk; int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length; int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size; LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride; LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length; LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder(); decoder.InitDecoder();
kaldi::Timer timer; kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) { for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key(); string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value(); kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0; int nframes = feature.NumRows();
int32 padding_len = 0; int feat_dim = feature.NumCols();
raw_data->SetDim(feat_dim);
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "feat shape: " << nframes << ", " << feat_dim;
raw_data->SetDim(feat_dim);
int32 ori_feature_len = feature.NumRows(); int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) { int32 num_chunks = feature.NumRows() / chunk_stride + 1;
padding_len = LOG(INFO) << "num_chunks: " << num_chunks;
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size * int32 this_chunk_size = 0;
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) { if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min( this_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size); ori_feature_len - chunk_idx * chunk_stride, chunk_size);
} }
if (feature_chunk_size < receptive_field_length) break; if (this_chunk_size < receptive_field_length) {
LOG(WARNING)
<< "utt: " << utt << " skip last " << this_chunk_size
<< " frames, expect is " << receptive_field_length;
break;
}
kaldi::Vector<kaldi::BaseFloat> feature_chunk(this_chunk_size *
feat_dim);
int32 start = chunk_idx * chunk_stride; int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < this_chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> feat_row(feature, start);
kaldi::SubVector<kaldi::BaseFloat> feature_chunk_row(
feature_chunk.Data() + row_id * feat_dim, feat_dim);
for (int row_id = 0; row_id < chunk_size; ++row_id) { feature_chunk_row.CopyFromVec(feat_row);
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start; ++start;
} }
// feat to frontend pipeline cache
raw_data->Accept(feature_chunk); raw_data->Accept(feature_chunk);
// send data finish signal
if (chunk_idx == num_chunks - 1) { if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished(); raw_data->SetFinished();
} }
// forward nnet
decoder.AdvanceDecode(decodable); decoder.AdvanceDecode(decodable);
LOG(INFO) << "Partial result: " << decoder.GetPartialResult();
} }
std::string result;
result = decoder.GetFinalBestPath(); decoder.FinalizeSearch();
// get 1-best result
std::string result = decoder.GetFinalBestPath();
// after process one utt, then reset state.
decodable->Reset(); decodable->Reset();
decoder.Reset(); decoder.Reset();
if (result.empty()) { if (result.empty()) {
// the TokenWriter can not write empty string. // the TokenWriter can not write empty string.
++num_err; ++num_err;
KALDI_LOG << " the result of " << utt << " is empty"; LOG(INFO) << " the result of " << utt << " is empty";
continue; continue;
} }
KALDI_LOG << " the result of " << utt << " is " << result;
LOG(INFO) << " the result of " << utt << " is " << result;
result_writer.Write(utt, result); result_writer.Write(utt, result);
++num_done; ++num_done;
} }
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
double elapsed = timer.Elapsed(); double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s"; LOG(INFO) << "Program cost:" << elapsed << " sec";
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1); return (num_done != 0 ? 0 : 1);
} }
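
The chunking arithmetic in the new test loop can be checked on its own. The sketch below reproduces only the size/stride computation and the "skip the short tail" rule from the new side of the diff; `ToyChunkPlan` and `PlanChunks` are hypothetical names, and no Kaldi or nnet types are involved.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct ToyChunkPlan {
    int chunk_size = 0;      // frames per full chunk
    int chunk_stride = 0;    // frames advanced between chunks
    std::vector<int> sizes;  // frames actually fed for each chunk
};

inline ToyChunkPlan PlanChunks(int num_frames,
                               int receptive_field_length,
                               int subsampling_rate,
                               int nnet_decoder_chunk) {
    assert(num_frames >= 0 && receptive_field_length > 0);
    assert(subsampling_rate > 0 && nnet_decoder_chunk > 0);
    ToyChunkPlan plan;
    plan.chunk_size = receptive_field_length +
                      (nnet_decoder_chunk - 1) * subsampling_rate;
    plan.chunk_stride = subsampling_rate * nnet_decoder_chunk;
    const int num_chunks = num_frames / plan.chunk_stride + 1;
    for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
        const int start = chunk_idx * plan.chunk_stride;
        int this_chunk_size = 0;
        if (num_frames > start) {
            this_chunk_size = std::min(num_frames - start, plan.chunk_size);
        }
        // A tail shorter than the receptive field cannot be forwarded.
        if (this_chunk_size < receptive_field_length) break;
        plan.sizes.push_back(this_chunk_size);
    }
    return plan;
}
```

With the flag defaults from this file (receptive field 7, subsampling rate 4, `nnet_decoder_chunk` 16) the chunk size is 67 frames and the stride is 64, so consecutive chunks overlap by 3 frames. A 200-frame utterance is fed as chunks of 67, 67, 67, and 8 frames, while a 195-frame utterance drops its 3-frame tail.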

@ -0,0 +1,98 @@
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from
// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/ctc_prefix_beam_search.h
#pragma once
#include "base/common.h"
#include "utils/math.h"
namespace ppspeech {
class ContextGraph;
struct PrefixScore {
// decoding, unit in log scale
float b = -kBaseFloatMax; // blank ending score
float nb = -kBaseFloatMax; // non-blank ending score
// decoding score, sum
float Score() const { return LogSumExp(b, nb); }
// timestamp, unit in log scale
float v_b = -kBaseFloatMax; // viterbi blank ending score
float v_nb = -kBaseFloatMax; // viterbi non-blank ending score
float cur_token_prob = -kBaseFloatMax; // prob of current token
std::vector<int> times_b; // times of viterbi blank path
std::vector<int> times_nb; // times of viterbi non-blank path
// timestamp score, max
float ViterbiScore() const { return std::max(v_b, v_nb); }
// get timestamp
const std::vector<int>& Times() const {
return v_b > v_nb ? times_b : times_nb;
}
// context state
bool has_context = false;
int context_state = 0;
float context_score = 0;
std::vector<int> start_boundaries;
std::vector<int> end_boundaries;
// decoding score with context bias
float TotalScore() const { return Score() + context_score; }
void CopyContext(const PrefixScore& prefix_score) {
context_state = prefix_score.context_state;
context_score = prefix_score.context_score;
start_boundaries = prefix_score.start_boundaries;
end_boundaries = prefix_score.end_boundaries;
}
void UpdateContext(const std::shared_ptr<ContextGraph>& context_graph,
const PrefixScore& prefix_score,
int word_id,
int prefix_len) {
CHECK(false);
}
void InitEmpty() {
b = 0.0f; // log(1)
nb = -kBaseFloatMax; // log(0)
v_b = 0.0f; // log(1)
v_nb = 0.0f; // log(1)
}
};
struct PrefixScoreHash {
// https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector
std::size_t operator()(const std::vector<int>& prefix) const {
std::size_t seed = prefix.size();
for (auto& i : prefix) {
seed ^= i + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
};
using PrefixWithScoreType = std::pair<std::vector<int>, PrefixScore>;
} // namespace ppspeech

@ -18,37 +18,38 @@ namespace ppspeech {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path)); fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
CHECK(fst_ != nullptr); CHECK(fst_ != nullptr);
word_symbol_table_.reset( word_symbol_table_.reset(
fst::SymbolTable::ReadText(opts.word_symbol_table)); fst::SymbolTable::ReadText(opts.word_symbol_table));
decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts)); decoder_.reset(new kaldi::LatticeFasterOnlineDecoder(*fst_, opts.opts));
decoder_->InitDecoding();
frame_decoded_size_ = 0; Reset();
} }
void TLGDecoder::InitDecoder() { void TLGDecoder::Reset() {
decoder_->InitDecoding(); decoder_->InitDecoding();
frame_decoded_size_ = 0; num_frame_decoded_ = 0;
return;
} }
void TLGDecoder::InitDecoder() { Reset(); }
void TLGDecoder::AdvanceDecode( void TLGDecoder::AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) { const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
while (!decodable->IsLastFrame(frame_decoded_size_)) { while (!decodable->IsLastFrame(num_frame_decoded_)) {
AdvanceDecoding(decodable.get()); AdvanceDecoding(decodable.get());
} }
} }
void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) { void TLGDecoder::AdvanceDecoding(kaldi::DecodableInterface* decodable) {
decoder_->AdvanceDecoding(decodable, 1); decoder_->AdvanceDecoding(decodable, 1);
frame_decoded_size_++; num_frame_decoded_++;
} }
void TLGDecoder::Reset() {
InitDecoder();
return;
}
std::string TLGDecoder::GetPartialResult() { std::string TLGDecoder::GetPartialResult() {
if (frame_decoded_size_ == 0) { if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.") // BestPathEnd if no frames were decoded.")
return std::string(""); return std::string("");
@ -68,7 +69,7 @@ std::string TLGDecoder::GetPartialResult() {
} }
std::string TLGDecoder::GetFinalBestPath() { std::string TLGDecoder::GetFinalBestPath() {
if (frame_decoded_size_ == 0) { if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.") // BestPathEnd if no frames were decoded.")
return std::string(""); return std::string("");
@ -88,4 +89,5 @@ std::string TLGDecoder::GetFinalBestPath() {
} }
return words; return words;
} }
}
} // namespace ppspeech
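
The refactor above makes `Reset()` the single place where decoder state is cleared, with the constructor and `InitDecoder()` delegating to it. A self-contained sketch of that pattern is below. `ToyDecoderInterface` and `ToyDecoder` are hypothetical stand-ins, and `AdvanceDecode` takes a frame count instead of a Kaldi decodable so the example compiles without external headers.

```cpp
#include <cassert>
#include <string>

class ToyDecoderInterface {
  public:
    virtual ~ToyDecoderInterface() {}
    virtual void InitDecoder() = 0;
    virtual void Reset() = 0;
    virtual void AdvanceDecode(int num_frames) = 0;
    virtual std::string GetPartialResult() = 0;
};

class ToyDecoder : public ToyDecoderInterface {
  public:
    ToyDecoder() { Reset(); }

    // All per-utterance state is cleared here and nowhere else.
    void Reset() override { num_frame_decoded_ = 0; }
    void InitDecoder() override { Reset(); }

    void AdvanceDecode(int num_frames) override {
        assert(num_frames >= 0);
        num_frame_decoded_ += num_frames;
    }

    std::string GetPartialResult() override {
        // Mirror the guard in the diff: nothing decoded means no result.
        if (num_frame_decoded_ == 0) return std::string("");
        return "decoded " + std::to_string(num_frame_decoded_) + " frames";
    }

  private:
    int num_frame_decoded_ = 0;
};
```

Giving the counter a default member initializer, as done here, is a choice of this sketch; it keeps the value defined even if a subclass constructor forgets to call `Reset()`.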

@ -14,37 +14,78 @@
#pragma once #pragma once
#include "base/basic_types.h" #include "base/common.h"
#include "kaldi/decoder/decodable-itf.h" #include "decoder/decoder_itf.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h" #include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h" #include "util/parse-options.h"
DECLARE_string(graph_path);
DECLARE_string(word_symbol_table);
DECLARE_int32(max_active);
DECLARE_double(beam);
DECLARE_double(lattice_beam);
namespace ppspeech { namespace ppspeech {
struct TLGDecoderOptions { struct TLGDecoderOptions {
kaldi::LatticeFasterDecoderConfig opts; kaldi::LatticeFasterDecoderConfig opts{};
// todo remove later, add into decode resource // todo remove later, add into decode resource
std::string word_symbol_table; std::string word_symbol_table;
std::string fst_path; std::string fst_path;
TLGDecoderOptions() : word_symbol_table(""), fst_path("") {} static TLGDecoderOptions InitFromFlags() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
LOG(INFO) << "fst path: " << decoder_opts.fst_path;
LOG(INFO) << "fst symbol table: " << decoder_opts.word_symbol_table;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
LOG(INFO) << "LatticeFasterDecoder max active: "
<< decoder_opts.opts.max_active;
LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam;
LOG(INFO) << "LatticeFasterDecoder lattice_beam: "
<< decoder_opts.opts.lattice_beam;
return decoder_opts;
}
}; };
class TLGDecoder { class TLGDecoder : public DecoderBase {
public: public:
explicit TLGDecoder(TLGDecoderOptions opts); explicit TLGDecoder(TLGDecoderOptions opts);
~TLGDecoder() = default;
void InitDecoder(); void InitDecoder();
void Decode(); void Reset();
std::string GetBestPath();
std::vector<std::pair<double, std::string>> GetNBestPath();
std::string GetFinalBestPath();
std::string GetPartialResult();
int NumFrameDecoded();
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
std::vector<std::string>& nbest_words);
void AdvanceDecode( void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable); const std::shared_ptr<kaldi::DecodableInterface>& decodable);
void Reset();
void Decode();
std::string GetFinalBestPath() override;
std::string GetPartialResult() override;
int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
const std::vector<std::string>& nbest_words);
protected:
std::string GetBestPath() override {
CHECK(false);
return {};
}
std::vector<std::pair<double, std::string>> GetNBestPath() override {
CHECK(false);
return {};
}
std::vector<std::pair<double, std::string>> GetNBestPath(int n) override {
CHECK(false);
return {};
}
private: private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable); void AdvanceDecoding(kaldi::DecodableInterface* decodable);
@ -52,8 +93,6 @@ class TLGDecoder {
std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_; std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_; std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_; std::shared_ptr<fst::SymbolTable> word_symbol_table_;
// the frame size which have decoded starts from 0.
int32 frame_decoded_size_;
}; };

@ -0,0 +1,137 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// todo refactor, replace with gtest
#include "base/common.h"
#include "decoder/ctc_tlg_decoder.h"
#include "decoder/param.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
using kaldi::BaseFloat;
using kaldi::Matrix;
using std::vector;
// test TLG decoder by feeding speech feature.
int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
int32 num_done = 0, num_err = 0;
ppspeech::TLGDecoderOptions opts =
ppspeech::TLGDecoderOptions::InitFromFlags();
opts.opts.beam = 15.0;
opts.opts.lattice_beam = 7.5;
ppspeech::TLGDecoder decoder(opts);
ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));
int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
decoder.InitDecoder();
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();
int32 row_idx = 0;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;
int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}
std::string result;
result = decoder.GetFinalBestPath();
decodable->Reset();
decoder.Reset();
if (result.empty()) {
// the TokenWriter can not write empty string.
++num_err;
KALDI_LOG << " the result of " << utt << " is empty";
continue;
}
KALDI_LOG << " the result of " << utt << " is " << result;
result_writer.Write(utt, result);
++num_done;
}
double elapsed = timer.Elapsed();
KALDI_LOG << " cost:" << elapsed << " s";
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
<< " with errors.";
return (num_done != 0 ? 0 : 1);
}
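The padding and chunk-count arithmetic in the loop above can be checked in isolation. The following is a minimal sketch rather than the project's code: `PaddingLen` and `NumChunks` are hypothetical helper names, and it assumes the utterance has at least `chunk_size` rows (the original loop does not state this precondition).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: rows of padding needed so that
// (num_rows - chunk_size) becomes a multiple of chunk_stride.
// Assumption: num_rows >= chunk_size and chunk_stride > 0.
int32_t PaddingLen(int32_t num_rows, int32_t chunk_size, int32_t chunk_stride) {
    assert(chunk_stride > 0 && num_rows >= chunk_size);
    int32_t rem = (num_rows - chunk_size) % chunk_stride;
    return rem == 0 ? 0 : chunk_stride - rem;
}

// Hypothetical helper: number of windows of chunk_size rows, taken every
// chunk_stride rows, that cover the padded feature matrix.
int32_t NumChunks(int32_t num_rows, int32_t chunk_size, int32_t chunk_stride) {
    int32_t padded = num_rows + PaddingLen(num_rows, chunk_size, chunk_stride);
    return (padded - chunk_size) / chunk_stride + 1;
}
```

For example, 100 rows with a window of 7 and a stride of 4 need 3 rows of padding and yield 25 chunks.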

@ -0,0 +1,66 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "base/common.h"
#include "kaldi/decoder/decodable-itf.h"
namespace ppspeech {
enum SearchType {
kPrefixBeamSearch = 0,
kWfstBeamSearch = 1,
};
class DecoderInterface {
public:
virtual ~DecoderInterface() {}
virtual void InitDecoder() = 0;
virtual void Reset() = 0;
// call AdvanceDecoding
virtual void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable) = 0;
// call GetBestPath
virtual std::string GetFinalBestPath() = 0;
virtual std::string GetPartialResult() = 0;
protected:
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
// virtual void Decode() = 0;
virtual std::string GetBestPath() = 0;
virtual std::vector<std::pair<double, std::string>> GetNBestPath() = 0;
virtual std::vector<std::pair<double, std::string>> GetNBestPath(int n) = 0;
};
class DecoderBase : public DecoderInterface {
protected:
// start from one
int NumFrameDecoded() { return num_frame_decoded_ + 1; }
protected:
// current decoding frame number, abs_time_step_
int32 num_frame_decoded_;
};
} // namespace ppspeech

@ -30,8 +30,11 @@ using std::vector;
// test decoder by feeding nnet posterior probability // test decoder by feeding nnet posterior probability
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialBaseFloatMatrixReader likelihood_reader( kaldi::SequentialBaseFloatMatrixReader likelihood_reader(
FLAGS_nnet_prob_respecifier); FLAGS_nnet_prob_respecifier);

@ -17,23 +17,29 @@
#include "base/common.h" #include "base/common.h"
#include "decoder/ctc_beam_search_decoder.h" #include "decoder/ctc_beam_search_decoder.h"
#include "decoder/ctc_tlg_decoder.h" #include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
// feature // feature
DEFINE_bool(use_fbank, false, "True for fbank; false for linear feature"); DEFINE_bool(use_fbank, false, "True for fbank; false for linear feature");
DEFINE_bool(fill_zero,
false,
"fill zero at last chunk, when chunk < chunk_size");
// DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear // DEFINE_bool(to_float32, true, "audio convert to pcm32. True for linear
// feature, or fbank"); // feature, or fbank");
DEFINE_int32(num_bins, 161, "num bins of mel"); DEFINE_int32(num_bins, 161, "num bins of mel");
DEFINE_string(cmvn_file, "", "read cmvn"); DEFINE_string(cmvn_file, "", "read cmvn");
// feature sliding window // feature sliding window
DEFINE_int32(receptive_field_length, DEFINE_int32(receptive_field_length,
7, 7,
"receptive field of two CNN(kernel=3) downsampling module."); "receptive field of two CNN(kernel=3) downsampling module.");
DEFINE_int32(downsampling_rate, DEFINE_int32(subsampling_rate,
4, 4,
"two CNN(kernel=3) module downsampling rate."); "two CNN(kernel=3) module downsampling rate.");
DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk"); DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
// nnet // nnet
DEFINE_string(vocab_path, "", "nnet vocab path.");
DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model"); DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param"); DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
DEFINE_string( DEFINE_string(
@ -48,71 +54,30 @@ DEFINE_string(model_cache_names,
"model cache names"); "model cache names");
DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes"); DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
// decoder // decoder
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale"); DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_int32(max_active, 7500, "max active"); DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam"); DEFINE_double(beam, 15.0, "decoder beam");
DEFINE_double(lattice_beam, 7.5, "decoder beam"); DEFINE_double(lattice_beam, 7.5, "decoder beam");
namespace ppspeech {
// todo refactor later
FeaturePipelineOptions InitFeaturePipelineOptions() {
FeaturePipelineOptions opts;
opts.cmvn_file = FLAGS_cmvn_file;
kaldi::FrameExtractionOptions frame_opts;
frame_opts.dither = 0.0;
frame_opts.frame_shift_ms = 10;
opts.use_fbank = FLAGS_use_fbank;
if (opts.use_fbank) {
opts.to_float32 = false;
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
opts.fbank_opts.frame_opts = frame_opts;
} else {
opts.to_float32 = true;
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
opts.assembler_opts.subsampling_rate = FLAGS_downsampling_rate;
opts.assembler_opts.receptive_filed_length = FLAGS_receptive_field_length;
opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk;
return opts;
}
ModelOptions InitModelOptions() {
ModelOptions model_opts;
model_opts.model_path = FLAGS_model_path;
model_opts.param_path = FLAGS_param_path;
model_opts.cache_names = FLAGS_model_cache_names;
model_opts.cache_shape = FLAGS_model_cache_shapes;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
return model_opts;
}
TLGDecoderOptions InitDecoderOptions() {
TLGDecoderOptions decoder_opts;
decoder_opts.word_symbol_table = FLAGS_word_symbol_table;
decoder_opts.fst_path = FLAGS_graph_path;
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
return decoder_opts;
}
RecognizerResource InitRecognizerResoure() { // DecodeOptions flags
RecognizerResource resource; // DEFINE_int32(chunk_size, -1, "decoding chunk size");
resource.acoustic_scale = FLAGS_acoustic_scale; DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
resource.feature_pipeline_opts = InitFeaturePipelineOptions(); DEFINE_double(ctc_weight,
resource.model_opts = InitModelOptions(); 0.5,
resource.tlg_opts = InitDecoderOptions(); "ctc weight when combining ctc score and rescoring score");
return resource; DEFINE_double(rescoring_weight,
} 1.0,
} "rescoring weight when combining ctc score and rescoring score");
DEFINE_double(reverse_weight,
0.3,
"used for bitransformer rescoring. It must be 0.0 if the decoder is a "
"conventional transformer decoder; the right-to-left decoder is "
"calculated and used only when reverse_weight > 0.0");
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");
DEFINE_int32(blank, 0, "blank id in vocab");

@ -1,5 +1,3 @@
project(frontend)
add_library(frontend STATIC add_library(frontend STATIC
cmvn.cc cmvn.cc
db_norm.cc db_norm.cc

@ -16,16 +16,18 @@
namespace ppspeech { namespace ppspeech {
using kaldi::BaseFloat;
using kaldi::Vector; using kaldi::Vector;
using kaldi::VectorBase; using kaldi::VectorBase;
using kaldi::BaseFloat;
using std::unique_ptr; using std::unique_ptr;
Assembler::Assembler(AssemblerOptions opts, Assembler::Assembler(AssemblerOptions opts,
unique_ptr<FrontendInterface> base_extractor) { unique_ptr<FrontendInterface> base_extractor) {
fill_zero_ = opts.fill_zero;
frame_chunk_stride_ = opts.subsampling_rate * opts.nnet_decoder_chunk; frame_chunk_stride_ = opts.subsampling_rate * opts.nnet_decoder_chunk;
frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate + frame_chunk_size_ = (opts.nnet_decoder_chunk - 1) * opts.subsampling_rate +
opts.receptive_filed_length; opts.receptive_filed_length;
cache_size_ = frame_chunk_size_ - frame_chunk_stride_;
receptive_filed_length_ = opts.receptive_filed_length; receptive_filed_length_ = opts.receptive_filed_length;
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
dim_ = base_extractor_->Dim(); dim_ = base_extractor_->Dim();
@ -38,49 +40,83 @@ void Assembler::Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
// pop feature chunk // pop feature chunk
bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) { bool Assembler::Read(kaldi::Vector<kaldi::BaseFloat>* feats) {
feats->Resize(dim_ * frame_chunk_size_);
bool result = Compute(feats); bool result = Compute(feats);
return result; return result;
} }
// read all data from base_feature_extractor_ into cache_ // read frame by frame from base_feature_extractor_ into cache_
bool Assembler::Compute(Vector<BaseFloat>* feats) { bool Assembler::Compute(Vector<BaseFloat>* feats) {
// compute and feed // compute and feed frame by frame
bool result = false;
while (feature_cache_.size() < frame_chunk_size_) { while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature; Vector<BaseFloat> feature;
result = base_extractor_->Read(&feature); bool result = base_extractor_->Read(&feature);
if (result == false || feature.Dim() == 0) { if (result == false || feature.Dim() == 0) {
if (IsFinished() == false) return false; VLOG(1) << "result: " << result
<< " feature dim: " << feature.Dim();
if (IsFinished() == false) {
VLOG(1) << "finished reading feature. cache size: "
<< feature_cache_.size();
return false;
} else {
VLOG(1) << "break";
break; break;
} }
}
CHECK(feature.Dim() == dim_);
feature_cache_.push(feature); feature_cache_.push(feature);
nframes_ += 1;
VLOG(1) << "nframes: " << nframes_;
} }
if (feature_cache_.size() < receptive_filed_length_) { if (feature_cache_.size() < receptive_filed_length_) {
VLOG(1) << "feature_cache less than receptive_filed_length. "
<< feature_cache_.size() << ": " << receptive_filed_length_;
return false; return false;
} }
if (fill_zero_) {
while (feature_cache_.size() < frame_chunk_size_) { while (feature_cache_.size() < frame_chunk_size_) {
Vector<BaseFloat> feature(dim_, kaldi::kSetZero); Vector<BaseFloat> feature(dim_, kaldi::kSetZero);
nframes_ += 1;
feature_cache_.push(feature); feature_cache_.push(feature);
} }
}
int32 this_chunk_size =
std::min(static_cast<int32>(feature_cache_.size()), frame_chunk_size_);
feats->Resize(dim_ * this_chunk_size);
VLOG(1) << "read " << this_chunk_size << " feat.";
int32 counter = 0; int32 counter = 0;
int32 cache_size = frame_chunk_size_ - frame_chunk_stride_; while (counter < this_chunk_size) {
int32 elem_dim = base_extractor_->Dim();
while (counter < frame_chunk_size_) {
Vector<BaseFloat>& val = feature_cache_.front(); Vector<BaseFloat>& val = feature_cache_.front();
int32 start = counter * elem_dim; CHECK(val.Dim() == dim_) << val.Dim();
feats->Range(start, elem_dim).CopyFromVec(val);
if (frame_chunk_size_ - counter <= cache_size) { int32 start = counter * dim_;
feats->Range(start, dim_).CopyFromVec(val);
if (this_chunk_size - counter <= cache_size_) {
feature_cache_.push(val); feature_cache_.push(val);
} }
// val is reference, so we should pop here
feature_cache_.pop(); feature_cache_.pop();
counter++; counter++;
} }
CHECK(feature_cache_.size() == cache_size_);
return result; return true;
}
void Assembler::Reset() {
std::queue<kaldi::Vector<kaldi::BaseFloat>> empty;
std::swap(feature_cache_, empty);
nframes_ = 0;
base_extractor_->Reset();
} }
} // namespace ppspeech } // namespace ppspeech
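The window, stride, and overlap bookkeeping in `Assembler` can be illustrated without Kaldi types. This is a simplified sketch, not the class itself: frames are plain `int` values, `WindowShape`, `MakeWindowShape`, and `PopWindow` are hypothetical names, and the cache is assumed to hold exactly one full window when it is read.

```cpp
#include <cassert>
#include <queue>
#include <vector>

// Hypothetical struct mirroring frame_chunk_stride_, frame_chunk_size_
// and cache_size_ from the constructor above.
struct WindowShape {
    int stride;   // frames consumed per read
    int size;     // frames returned per read
    int overlap;  // frames kept for the next read (size - stride)
};

WindowShape MakeWindowShape(int receptive_field_length,
                            int subsampling_rate,
                            int nnet_decoder_chunk) {
    WindowShape s;
    s.stride = subsampling_rate * nnet_decoder_chunk;
    s.size = (nnet_decoder_chunk - 1) * subsampling_rate +
             receptive_field_length;
    s.overlap = s.size - s.stride;
    return s;
}

// Pops one window from the cache and re-queues its last `overlap` frames,
// so the next window starts `stride` frames later.
// Assumption: the cache holds exactly shape.size frames on entry.
std::vector<int> PopWindow(std::queue<int>* cache, const WindowShape& shape) {
    assert(static_cast<int>(cache->size()) == shape.size);
    std::vector<int> window;
    for (int counter = 0; counter < shape.size; ++counter) {
        int frame = cache->front();
        window.push_back(frame);
        if (shape.size - counter <= shape.overlap) {
            cache->push(frame);  // keep the tail for the next window
        }
        cache->pop();
    }
    return window;
}
```

With the default flag values in this diff (receptive field 7, subsampling rate 4, one decoder chunk) the window is 7 frames, the stride is 4, and 3 frames carry over between reads.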

@ -22,14 +22,11 @@ namespace ppspeech {
struct AssemblerOptions { struct AssemblerOptions {
// refer:https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/s2t/exps/deepspeech2/model.py // refer:https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/s2t/exps/deepspeech2/model.py
// the nnet batch forward // the nnet batch forward
int32 receptive_filed_length; int32 receptive_filed_length{1};
int32 subsampling_rate; int32 subsampling_rate{1};
int32 nnet_decoder_chunk; int32 nnet_decoder_chunk{1};
bool fill_zero{false}; // whether fill zero when last chunk is not equal to
AssemblerOptions() // frame_chunk_size_
: receptive_filed_length(1),
subsampling_rate(1),
nnet_decoder_chunk(1) {}
}; };
class Assembler : public FrontendInterface { class Assembler : public FrontendInterface {
@ -39,29 +36,34 @@ class Assembler : public FrontendInterface {
std::unique_ptr<FrontendInterface> base_extractor = NULL); std::unique_ptr<FrontendInterface> base_extractor = NULL);
// Feed feats or waves // Feed feats or waves
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs); void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) override;
// feats size = num_frames * feat_dim // feats size = num_frames * feat_dim
virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats); bool Read(kaldi::Vector<kaldi::BaseFloat>* feats) override;
// feat dim // feat dim
virtual size_t Dim() const { return dim_; } size_t Dim() const override { return dim_; }
virtual void SetFinished() { base_extractor_->SetFinished(); } void SetFinished() override { base_extractor_->SetFinished(); }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } bool IsFinished() const override { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); } void Reset() override;
private: private:
bool Compute(kaldi::Vector<kaldi::BaseFloat>* feats); bool Compute(kaldi::Vector<kaldi::BaseFloat>* feats);
int32 dim_; bool fill_zero_{false};
int32 dim_; // feat dim
int32 frame_chunk_size_; // window int32 frame_chunk_size_; // window
int32 frame_chunk_stride_; // stride int32 frame_chunk_stride_; // stride
int32 cache_size_; // window - stride
int32 receptive_filed_length_; int32 receptive_filed_length_;
std::queue<kaldi::Vector<kaldi::BaseFloat>> feature_cache_; std::queue<kaldi::Vector<kaldi::BaseFloat>> feature_cache_;
std::unique_ptr<FrontendInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
int32 nframes_; // num frame computed
DISALLOW_COPY_AND_ASSIGN(Assembler); DISALLOW_COPY_AND_ASSIGN(Assembler);
}; };

@ -13,13 +13,14 @@
// limitations under the License. // limitations under the License.
#include "frontend/audio/audio_cache.h" #include "frontend/audio/audio_cache.h"
#include "kaldi/base/timer.h" #include "kaldi/base/timer.h"
namespace ppspeech { namespace ppspeech {
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::VectorBase;
using kaldi::Vector; using kaldi::Vector;
using kaldi::VectorBase;
AudioCache::AudioCache(int buffer_size, bool to_float32) AudioCache::AudioCache(int buffer_size, bool to_float32)
: finished_(false), : finished_(false),
@ -83,6 +84,10 @@ bool AudioCache::Read(Vector<BaseFloat>* waves) {
} }
size_ -= chunk_size; size_ -= chunk_size;
offset_ = (offset_ + chunk_size) % ring_buffer_.size(); offset_ = (offset_ + chunk_size) % ring_buffer_.size();
nsamples_ += chunk_size;
VLOG(1) << "nsamples read: " << nsamples_;
ready_feed_condition_.notify_one(); ready_feed_condition_.notify_one();
return true; return true;
} }

@ -41,10 +41,11 @@ class AudioCache : public FrontendInterface {
virtual bool IsFinished() const { return finished_; } virtual bool IsFinished() const { return finished_; }
virtual void Reset() { void Reset() override {
offset_ = 0; offset_ = 0;
size_ = 0; size_ = 0;
finished_ = false; finished_ = false;
nsamples_ = 0;
} }
private: private:
@ -61,6 +62,7 @@ class AudioCache : public FrontendInterface {
kaldi::int32 timeout_; // millisecond kaldi::int32 timeout_; // millisecond
bool to_float32_; // int16 -> float32. used in linear_spectrogram bool to_float32_; // int16 -> float32. used in linear_spectrogram
int32 nsamples_; // number of samples read.
DISALLOW_COPY_AND_ASSIGN(AudioCache); DISALLOW_COPY_AND_ASSIGN(AudioCache);
}; };

@ -14,22 +14,25 @@
#include "frontend/audio/cmvn.h" #include "frontend/audio/cmvn.h"
#include "kaldi/feat/cmvn.h" #include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h" #include "kaldi/util/kaldi-io.h"
namespace ppspeech { namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr; using std::unique_ptr;
using std::vector;
CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor) CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)
: var_norm_(true) { : var_norm_(true) {
CHECK_NE(cmvn_file, "");
base_extractor_ = std::move(base_extractor); base_extractor_ = std::move(base_extractor);
bool binary; bool binary;
kaldi::Input ki(cmvn_file, &binary); kaldi::Input ki(cmvn_file, &binary);
stats_.Read(ki.Stream(), binary); stats_.Read(ki.Stream(), binary);
@ -55,11 +58,11 @@ bool CMVN::Read(kaldi::Vector<BaseFloat>* feats) {
// feats contain num_frames feature. // feats contain num_frames feature.
void CMVN::Compute(VectorBase<BaseFloat>* feats) const { void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
KALDI_ASSERT(feats != NULL); KALDI_ASSERT(feats != NULL);
int32 dim = stats_.NumCols() - 1;
if (stats_.NumRows() > 2 || stats_.NumRows() < 1 || if (stats_.NumRows() > 2 || stats_.NumRows() < 1 ||
feats->Dim() % dim != 0) { feats->Dim() % dim_ != 0) {
KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << 'x' KALDI_ERR << "Dim mismatch: cmvn " << stats_.NumRows() << ','
<< stats_.NumCols() << ", feats " << feats->Dim() << 'x'; << stats_.NumCols() - 1 << ", feats " << feats->Dim() << 'x';
} }
if (stats_.NumRows() == 1 && var_norm_) { if (stats_.NumRows() == 1 && var_norm_) {
KALDI_ERR KALDI_ERR
@ -67,7 +70,7 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
<< "are supplied."; << "are supplied.";
} }
double count = stats_(0, dim); double count = stats_(0, dim_);
// Do not change the threshold of 1.0 here: in the balanced-cmvn code, when // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
// computing an offset and representing it as stats_, we use a count of one. // computing an offset and representing it as stats_, we use a count of one.
if (count < 1.0) if (count < 1.0)
@ -77,14 +80,14 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
if (!var_norm_) { if (!var_norm_) {
Vector<BaseFloat> offset(feats->Dim()); Vector<BaseFloat> offset(feats->Dim());
SubVector<double> mean_stats(stats_.RowData(0), dim); SubVector<double> mean_stats(stats_.RowData(0), dim_);
Vector<double> mean_stats_apply(feats->Dim()); Vector<double> mean_stats_apply(feats->Dim());
// fill the data of mean_stats in mean_stats_apply whose dim is equal // fill the data of mean_stats in mean_stats_apply whose dim_ is equal
// with the dim of feature. // with the dim_ of feature.
// the dim of feats = dim * num_frames; // the dim_ of feats = dim_ * num_frames;
for (int32 idx = 0; idx < feats->Dim() / dim; ++idx) { for (int32 idx = 0; idx < feats->Dim() / dim_; ++idx) {
SubVector<double> stats_tmp(mean_stats_apply.Data() + dim * idx, SubVector<double> stats_tmp(mean_stats_apply.Data() + dim_ * idx,
dim); dim_);
stats_tmp.CopyFromVec(mean_stats); stats_tmp.CopyFromVec(mean_stats);
} }
offset.AddVec(-1.0 / count, mean_stats_apply); offset.AddVec(-1.0 / count, mean_stats_apply);
@ -94,7 +97,7 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
// norm(0, d) = mean offset; // norm(0, d) = mean offset;
// norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d). // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
kaldi::Matrix<BaseFloat> norm(2, feats->Dim()); kaldi::Matrix<BaseFloat> norm(2, feats->Dim());
for (int32 d = 0; d < dim; d++) { for (int32 d = 0; d < dim_; d++) {
double mean, offset, scale; double mean, offset, scale;
mean = stats_(0, d) / count; mean = stats_(0, d) / count;
double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20; double var = (stats_(1, d) / count) - mean * mean, floor = 1.0e-20;
@ -111,7 +114,7 @@ void CMVN::Compute(VectorBase<BaseFloat>* feats) const {
for (int32 d_skip = d; d_skip < feats->Dim();) { for (int32 d_skip = d; d_skip < feats->Dim();) {
norm(0, d_skip) = offset; norm(0, d_skip) = offset;
norm(1, d_skip) = scale; norm(1, d_skip) = scale;
d_skip = d_skip + dim; d_skip = d_skip + dim_;
} }
} }
// Apply the normalization. // Apply the normalization.
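The per-dimension offset and scale that this hunk builds can be sketched on plain vectors. The mean, variance, and floor below follow the lines shown in the hunk; the `scale = 1 / sqrt(var)` and `offset = -mean * scale` steps are not visible in this diff and are an assumption based on the usual mean-variance normalization. `Norm`, `ComputeNorm`, and `ApplyNorm` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical struct: x <- x * scale + offset for one feature dimension.
struct Norm {
    double offset;
    double scale;
};

// sum is the accumulated x, sum_sq the accumulated x^2, count the number
// of frames, as stored in rows 0 and 1 of the CMVN stats matrix.
Norm ComputeNorm(double sum, double sum_sq, double count) {
    assert(count >= 1.0);
    const double floor = 1.0e-20;
    double mean = sum / count;
    double var = sum_sq / count - mean * mean;
    if (var < floor) var = floor;  // guard against zero variance
    // Assumed step (not shown in the hunk): unit-variance scaling.
    double scale = 1.0 / std::sqrt(var);
    double offset = -mean * scale;
    return Norm{offset, scale};
}

// feats holds num_frames * dim values, frame after frame; the same
// per-dimension norm is applied to every frame.
void ApplyNorm(const std::vector<Norm>& norm, std::vector<double>* feats) {
    std::size_t dim = norm.size();
    assert(dim > 0 && feats->size() % dim == 0);
    for (std::size_t i = 0; i < feats->size(); ++i) {
        const Norm& n = norm[i % dim];
        (*feats)[i] = (*feats)[i] * n.scale + n.offset;
    }
}
```

Indexing with `i % dim` plays the same role as the `d_skip` loop above, which repeats each dimension's offset and scale across all frames in the flattened vector.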

@ -30,8 +30,11 @@ DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");
using namespace boost::json; // from <boost/json.hpp> using namespace boost::json; // from <boost/json.hpp>
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
LOG(INFO) << "cmvn json path: " << FLAGS_json_file; LOG(INFO) << "cmvn json path: " << FLAGS_json_file;
@ -44,13 +47,13 @@ int main(int argc, char* argv[]) {
for (auto obj : value.as_object()) { for (auto obj : value.as_object()) {
if (obj.key() == "mean_stat") { if (obj.key() == "mean_stat") {
LOG(INFO) << "mean_stat:" << obj.value(); VLOG(2) << "mean_stat:" << obj.value();
} }
if (obj.key() == "var_stat") { if (obj.key() == "var_stat") {
LOG(INFO) << "var_stat: " << obj.value(); VLOG(2) << "var_stat: " << obj.value();
} }
if (obj.key() == "frame_num") { if (obj.key() == "frame_num") {
LOG(INFO) << "frame_num: " << obj.value(); VLOG(2) << "frame_num: " << obj.value();
} }
} }
@ -76,7 +79,7 @@ int main(int argc, char* argv[]) {
cmvn_stats(1, idx) = var_stat_vec[idx]; cmvn_stats(1, idx) = var_stat_vec[idx];
} }
cmvn_stats(0, mean_size) = frame_num; cmvn_stats(0, mean_size) = frame_num;
LOG(INFO) << cmvn_stats; VLOG(2) << cmvn_stats;
kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary); kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have been written to: " << FLAGS_cmvn_write_path; LOG(INFO) << "cmvn stats have been written to: " << FLAGS_cmvn_write_path;

@ -16,29 +16,36 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h" #include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "frontend/audio/fbank.h" #include "frontend/audio/fbank.h"
#include "frontend/audio/feature_cache.h" #include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h" #include "frontend/audio/frontend_itf.h"
#include "frontend/audio/normalizer.h" #include "frontend/audio/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
DEFINE_string(cmvn_file, "", "read cmvn"); DEFINE_string(cmvn_file, "", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
DEFINE_int32(num_bins, 161, "fbank num bins"); DEFINE_int32(num_bins, 161, "fbank num bins");
DEFINE_int32(sample_rate, 16000, "sample rate: 16k, 8k.");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
CHECK_GT(FLAGS_feature_wspecifier.size(), 0);
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader( kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier); FLAGS_wav_rspecifier);
kaldi::SequentialTableReader<kaldi::WaveInfoHolder> wav_info_reader(
FLAGS_wav_rspecifier);
kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier); kaldi::BaseFloatMatrixWriter feat_writer(FLAGS_feature_wspecifier);
int32 num_done = 0, num_err = 0; int32 num_done = 0, num_err = 0;
@ -54,6 +61,10 @@ int main(int argc, char* argv[]) {
opt.frame_opts.frame_shift_ms = 10; opt.frame_opts.frame_shift_ms = 10;
opt.mel_opts.num_bins = FLAGS_num_bins; opt.mel_opts.num_bins = FLAGS_num_bins;
opt.frame_opts.dither = 0.0; opt.frame_opts.dither = 0.0;
LOG(INFO) << "frame_length_ms: " << opt.frame_opts.frame_length_ms;
LOG(INFO) << "frame_shift_ms: " << opt.frame_opts.frame_shift_ms;
LOG(INFO) << "num_bins: " << opt.mel_opts.num_bins;
LOG(INFO) << "dither: " << opt.frame_opts.dither;
std::unique_ptr<ppspeech::FrontendInterface> fbank( std::unique_ptr<ppspeech::FrontendInterface> fbank(
new ppspeech::Fbank(opt, std::move(data_source))); new ppspeech::Fbank(opt, std::move(data_source)));
@ -61,53 +72,76 @@ int main(int argc, char* argv[]) {
std::unique_ptr<ppspeech::FrontendInterface> cmvn( std::unique_ptr<ppspeech::FrontendInterface> cmvn(
new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank))); new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank)));
ppspeech::FeatureCacheOptions feat_cache_opts;
// the feature cache output feature chunk by chunk. // the feature cache output feature chunk by chunk.
ppspeech::FeatureCacheOptions feat_cache_opts;
ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn));
LOG(INFO) << "fbank: " << true; LOG(INFO) << "fbank: " << true;
LOG(INFO) << "feat dim: " << feature_cache.Dim(); LOG(INFO) << "feat dim: " << feature_cache.Dim();
int sample_rate = 16000;
float streaming_chunk = FLAGS_streaming_chunk; float streaming_chunk = FLAGS_streaming_chunk;
int chunk_sample_size = streaming_chunk * sample_rate; int chunk_sample_size = streaming_chunk * FLAGS_sample_rate;
LOG(INFO) << "sr: " << sample_rate; LOG(INFO) << "sr: " << FLAGS_sample_rate;
LOG(INFO) << "chunk size (s): " << streaming_chunk; LOG(INFO) << "chunk size (sec): " << streaming_chunk;
LOG(INFO) << "chunk size (sample): " << chunk_sample_size; LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
for (; !wav_reader.Done(); wav_reader.Next()) { for (; !wav_reader.Done() && !wav_info_reader.Done();
std::string utt = wav_reader.Key(); wav_reader.Next(), wav_info_reader.Next()) {
const std::string& utt = wav_reader.Key();
const kaldi::WaveData& wave_data = wav_reader.Value(); const kaldi::WaveData& wave_data = wav_reader.Value();
LOG(INFO) << "process utt: " << utt;
const std::string& utt2 = wav_info_reader.Key();
const kaldi::WaveInfo& wave_info = wav_info_reader.Value();
CHECK(utt == utt2)
<< "wav reader and wav info reader using diff rspecifier!!!";
LOG(INFO) << "utt: " << utt;
LOG(INFO) << "samples: " << wave_info.SampleCount();
LOG(INFO) << "dur: " << wave_info.Duration() << " sec";
CHECK(wave_info.SampFreq() == FLAGS_sample_rate)
<< "need " << FLAGS_sample_rate << " get " << wave_info.SampFreq();
// load first channel wav
int32 this_channel = 0; int32 this_channel = 0;
kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(), kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
this_channel); this_channel);
int tot_samples = waveform.Dim();
LOG(INFO) << "wav len (sample): " << tot_samples;
// compute feat chunk by chunk
int tot_samples = waveform.Dim();
int sample_offset = 0; int sample_offset = 0;
std::vector<kaldi::Vector<BaseFloat>> feats; std::vector<kaldi::Vector<BaseFloat>> feats;
int feature_rows = 0; int feature_rows = 0;
while (sample_offset < tot_samples) { while (sample_offset < tot_samples) {
// cur chunk size
int cur_chunk_size = int cur_chunk_size =
std::min(chunk_sample_size, tot_samples - sample_offset); std::min(chunk_sample_size, tot_samples - sample_offset);
// get chunk wav
kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size); kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
for (int i = 0; i < cur_chunk_size; ++i) { for (int i = 0; i < cur_chunk_size; ++i) {
wav_chunk(i) = waveform(sample_offset + i); wav_chunk(i) = waveform(sample_offset + i);
} }
kaldi::Vector<BaseFloat> features; // compute feat
feature_cache.Accept(wav_chunk); feature_cache.Accept(wav_chunk);
// send finish signal
if (cur_chunk_size < chunk_sample_size) { if (cur_chunk_size < chunk_sample_size) {
feature_cache.SetFinished(); feature_cache.SetFinished();
} }
// read feat
kaldi::Vector<BaseFloat> features;
bool flag = true; bool flag = true;
do { do {
flag = feature_cache.Read(&features); flag = feature_cache.Read(&features);
if (flag && features.Dim() != 0) {
feats.push_back(features); feats.push_back(features);
feature_rows += features.Dim() / feature_cache.Dim(); feature_rows += features.Dim() / feature_cache.Dim();
}
} while (flag == true && features.Dim() != 0); } while (flag == true && features.Dim() != 0);
// forward offset
sample_offset += cur_chunk_size; sample_offset += cur_chunk_size;
} }
@ -125,14 +159,20 @@ int main(int argc, char* argv[]) {
++cur_idx; ++cur_idx;
} }
} }
LOG(INFO) << "feat shape: " << features.NumRows() << " , "
<< features.NumCols();
feat_writer.Write(utt, features); feat_writer.Write(utt, features);
// reset frontend pipeline state
feature_cache.Reset(); feature_cache.Reset();
if (num_done % 50 == 0 && num_done != 0) if (num_done % 50 == 0 && num_done != 0)
KALDI_VLOG(2) << "Processed " << num_done << " utterances"; VLOG(2) << "Processed " << num_done << " utterances";
num_done++; num_done++;
} }
KALDI_LOG << "Done " << num_done << " utterances, " << num_err
LOG(INFO) << "Done " << num_done << " utterances, " << num_err
<< " with errors."; << " with errors.";
return (num_done != 0 ? 0 : 1); return (num_done != 0 ? 0 : 1);
} }
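The streaming loop in this file splits each waveform into fixed-size chunks and signals the end of input on a short final chunk. A minimal sketch of that splitting, with samples reduced to offsets and sizes, might look as follows; `Chunk` and `SplitIntoChunks` are hypothetical names, not part of the tool.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical record of one streaming chunk.
struct Chunk {
    int offset;  // first sample of the chunk
    int size;    // number of samples in the chunk
    bool last;   // true when the chunk is shorter than chunk_sample_size
};

// Mirrors the while loop above: take chunk_sample_size samples at a time
// and flag a chunk as last only when it comes up short.
std::vector<Chunk> SplitIntoChunks(int tot_samples, int chunk_sample_size) {
    assert(chunk_sample_size > 0);
    std::vector<Chunk> chunks;
    int sample_offset = 0;
    while (sample_offset < tot_samples) {
        int cur_chunk_size =
            std::min(chunk_sample_size, tot_samples - sample_offset);
        Chunk c;
        c.offset = sample_offset;
        c.size = cur_chunk_size;
        c.last = cur_chunk_size < chunk_sample_size;
        chunks.push_back(c);
        sample_offset += cur_chunk_size;
    }
    return chunks;
}
```

One property this sketch makes visible: under the `cur_chunk_size < chunk_sample_size` condition, a waveform whose length is an exact multiple of the chunk size never produces a chunk flagged as last, so a caller relying on that flag would need a separate end-of-input signal in that case.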

@ -14,16 +14,15 @@
#include "base/flags.h" #include "base/flags.h"
#include "base/log.h" #include "base/log.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
#include "frontend/audio/audio_cache.h" #include "frontend/audio/audio_cache.h"
#include "frontend/audio/data_cache.h" #include "frontend/audio/data_cache.h"
#include "frontend/audio/feature_cache.h" #include "frontend/audio/feature_cache.h"
#include "frontend/audio/frontend_itf.h" #include "frontend/audio/frontend_itf.h"
#include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h" #include "frontend/audio/normalizer.h"
#include "kaldi/feat/wave-reader.h"
#include "kaldi/util/kaldi-io.h"
#include "kaldi/util/table-types.h"
DEFINE_string(wav_rspecifier, "", "test wav scp path"); DEFINE_string(wav_rspecifier, "", "test wav scp path");
DEFINE_string(feature_wspecifier, "", "output feats wspecifier"); DEFINE_string(feature_wspecifier, "", "output feats wspecifier");
@ -31,8 +30,11 @@ DEFINE_string(cmvn_file, "./cmvn.ark", "read cmvn");
DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
gflags::SetUsageMessage("Usage:");
gflags::ParseCommandLineFlags(&argc, &argv, false); gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]); google::InitGoogleLogging(argv[0]);
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;
kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader( kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
FLAGS_wav_rspecifier); FLAGS_wav_rspecifier);

@ -27,7 +27,7 @@ namespace ppspeech {
// pre-recorded audio/feature // pre-recorded audio/feature
class DataCache : public FrontendInterface { class DataCache : public FrontendInterface {
public: public:
explicit DataCache() { finished_ = false; } DataCache() { finished_ = false; }
// accept waves/feats // accept waves/feats
virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) { virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs) {
@ -56,4 +56,4 @@ class DataCache : public FrontendInterface {
DISALLOW_COPY_AND_ASSIGN(DataCache); DISALLOW_COPY_AND_ASSIGN(DataCache);
}; };
} } // namespace ppspeech

@ -14,17 +14,18 @@
#include "frontend/audio/db_norm.h" #include "frontend/audio/db_norm.h"
#include "kaldi/feat/cmvn.h" #include "kaldi/feat/cmvn.h"
#include "kaldi/util/kaldi-io.h" #include "kaldi/util/kaldi-io.h"
namespace ppspeech { namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr; using std::unique_ptr;
using std::vector;
DecibelNormalizer::DecibelNormalizer( DecibelNormalizer::DecibelNormalizer(
const DecibelNormalizerOptions& opts, const DecibelNormalizerOptions& opts,

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "frontend/audio/fbank.h" #include "frontend/audio/fbank.h"
#include "kaldi/base/kaldi-math.h" #include "kaldi/base/kaldi-math.h"
#include "kaldi/feat/feature-common.h" #include "kaldi/feat/feature-common.h"
#include "kaldi/feat/feature-functions.h" #include "kaldi/feat/feature-functions.h"
@ -20,12 +21,12 @@
namespace ppspeech { namespace ppspeech {
using kaldi::int32;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Vector; using kaldi::int32;
using kaldi::Matrix;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase; using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector; using std::vector;
FbankComputer::FbankComputer(const Options& opts) FbankComputer::FbankComputer(const Options& opts)

@ -16,12 +16,12 @@
namespace ppspeech { namespace ppspeech {
using kaldi::Vector;
using kaldi::VectorBase;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using std::vector;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase;
using std::unique_ptr; using std::unique_ptr;
using std::vector;
FeatureCache::FeatureCache(FeatureCacheOptions opts, FeatureCache::FeatureCache(FeatureCacheOptions opts,
unique_ptr<FrontendInterface> base_extractor) { unique_ptr<FrontendInterface> base_extractor) {
@ -73,6 +73,9 @@ bool FeatureCache::Compute() {
if (result == false || feature.Dim() == 0) return false; if (result == false || feature.Dim() == 0) return false;
int32 num_chunk = feature.Dim() / dim_; int32 num_chunk = feature.Dim() / dim_;
nframe_ += num_chunk;
VLOG(1) << "nframe computed: " << nframe_;
for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
int32 start = chunk_idx * dim_; int32 start = chunk_idx * dim_;
Vector<BaseFloat> feature_chunk(dim_); Vector<BaseFloat> feature_chunk(dim_);
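
Review note: the loop in this hunk treats `feature` as a flat buffer of concatenated frames and cuts it into `dim_`-sized chunks, with `nframe_` counting the whole chunks. The following is a minimal sketch of that slicing using only the standard library. `SplitIntoFrames` is a hypothetical helper and `std::vector<float>` stands in for `kaldi::Vector<BaseFloat>`. It is not the repository's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <vector>

// Hypothetical helper: split a flat buffer of concatenated frames into
// per-frame vectors of length `dim`. Trailing samples that do not fill a
// whole frame are dropped, mirroring the integer division in the hunk.
std::queue<std::vector<float>> SplitIntoFrames(const std::vector<float>& flat,
                                               std::size_t dim) {
    std::queue<std::vector<float>> frames;
    if (dim == 0) return frames;  // guard against division by zero
    const std::size_t num_chunk = flat.size() / dim;
    for (std::size_t chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) {
        const std::size_t start = chunk_idx * dim;
        // Copy the half-open range [start, start + dim) into a new frame.
        frames.emplace(flat.begin() + start, flat.begin() + start + dim);
    }
    return frames;
}
```

With a 7-element buffer and `dim = 3`, the sketch yields two frames and discards the last element. The guard for `dim == 0` is an addition of the sketch. Whether `dim_` can be zero in the real class is not visible in this diff.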

@ -41,21 +41,24 @@ class FeatureCache : public FrontendInterface {
virtual size_t Dim() const { return dim_; } virtual size_t Dim() const { return dim_; }
virtual void SetFinished() { virtual void SetFinished() {
LOG(INFO) << "set finished";
// std::unique_lock<std::mutex> lock(mutex_); // std::unique_lock<std::mutex> lock(mutex_);
base_extractor_->SetFinished(); base_extractor_->SetFinished();
LOG(INFO) << "set finished";
// read the last chunk data // read the last chunk data
Compute(); Compute();
// ready_feed_condition_.notify_one(); // ready_feed_condition_.notify_one();
LOG(INFO) << "compute last feats done.";
} }
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { void Reset() override {
std::queue<kaldi::Vector<BaseFloat>> empty;
std::swap(cache_, empty);
nframe_ = 0;
base_extractor_->Reset(); base_extractor_->Reset();
while (!cache_.empty()) { VLOG(1) << "feature cache reset: cache size: " << cache_.size();
cache_.pop();
}
} }
private: private:
@ -74,6 +77,7 @@ class FeatureCache : public FrontendInterface {
std::condition_variable ready_feed_condition_; std::condition_variable ready_feed_condition_;
std::condition_variable ready_read_condition_; std::condition_variable ready_read_condition_;
int32 nframe_; // num of feature computed
DISALLOW_COPY_AND_ASSIGN(FeatureCache); DISALLOW_COPY_AND_ASSIGN(FeatureCache);
}; };
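
Review note: the new `Reset()` replaces the pop loop with a swap against an empty queue and also clears the new `nframe_` counter. Below is a minimal sketch of that idiom, assuming a hypothetical `FrameCache` class that keeps only the members `Reset()` touches. The names are illustrative and the locking, logging and base extractor of the real class are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical stand-in for FeatureCache: only the members touched by Reset().
class FrameCache {
  public:
    void Push(std::vector<float> frame) {
        cache_.push(std::move(frame));
        ++nframe_;
    }

    // std::queue has no clear(); swapping with an empty queue discards all
    // elements at once, which is what the new Reset() in the hunk does.
    void Reset() {
        std::queue<std::vector<float>> empty;
        std::swap(cache_, empty);
        nframe_ = 0;
    }

    std::size_t Size() const { return cache_.size(); }
    int NumFrames() const { return nframe_; }

  private:
    std::queue<std::vector<float>> cache_;
    int nframe_ = 0;  // number of frames pushed since the last Reset()
};
```

One observation on the hunk itself: the `VLOG(1)` line runs after the swap, so the cache size it reports is always 0. If the size before the reset is the interesting value, the log statement would have to come first.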

@ -18,7 +18,8 @@ namespace ppspeech {
using std::unique_ptr; using std::unique_ptr;
FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) { FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts)
: opts_(opts) {
unique_ptr<FrontendInterface> data_source( unique_ptr<FrontendInterface> data_source(
new ppspeech::AudioCache(1000 * kint16max, opts.to_float32)); new ppspeech::AudioCache(1000 * kint16max, opts.to_float32));
@ -32,6 +33,7 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
opts.linear_spectrogram_opts, std::move(data_source))); opts.linear_spectrogram_opts, std::move(data_source)));
} }
CHECK_NE(opts.cmvn_file, "");
unique_ptr<FrontendInterface> cmvn( unique_ptr<FrontendInterface> cmvn(
new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature)));
@ -42,4 +44,4 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) {
new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); new ppspeech::Assembler(opts.assembler_opts, std::move(cache)));
} }
} // ppspeech } // namespace ppspeech

@ -25,27 +25,78 @@
#include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/linear_spectrogram.h"
#include "frontend/audio/normalizer.h" #include "frontend/audio/normalizer.h"
// feature
DECLARE_bool(use_fbank);
DECLARE_bool(fill_zero);
DECLARE_int32(num_bins);
DECLARE_string(cmvn_file);
// feature sliding window
DECLARE_int32(receptive_field_length);
DECLARE_int32(subsampling_rate);
DECLARE_int32(nnet_decoder_chunk);
namespace ppspeech { namespace ppspeech {
struct FeaturePipelineOptions { struct FeaturePipelineOptions {
std::string cmvn_file; std::string cmvn_file{};
bool to_float32; // true, only for linear feature bool to_float32{false}; // true, only for linear feature
bool use_fbank; bool use_fbank{true};
LinearSpectrogramOptions linear_spectrogram_opts; LinearSpectrogramOptions linear_spectrogram_opts{};
kaldi::FbankOptions fbank_opts; kaldi::FbankOptions fbank_opts{};
FeatureCacheOptions feature_cache_opts; FeatureCacheOptions feature_cache_opts{};
AssemblerOptions assembler_opts; AssemblerOptions assembler_opts{};
FeaturePipelineOptions() static FeaturePipelineOptions InitFromFlags() {
: cmvn_file(""), FeaturePipelineOptions opts;
to_float32(false), // true, only for linear feature opts.cmvn_file = FLAGS_cmvn_file;
use_fbank(true), LOG(INFO) << "cmvn file: " << opts.cmvn_file;
linear_spectrogram_opts(),
fbank_opts(), // frame options
feature_cache_opts(), kaldi::FrameExtractionOptions frame_opts;
assembler_opts() {} frame_opts.dither = 0.0;
LOG(INFO) << "dither: " << frame_opts.dither;
frame_opts.frame_shift_ms = 10;
LOG(INFO) << "frame shift ms: " << frame_opts.frame_shift_ms;
opts.use_fbank = FLAGS_use_fbank;
LOG(INFO) << "feature type: " << (opts.use_fbank ? "fbank" : "linear");
if (opts.use_fbank) {
opts.to_float32 = false;
frame_opts.window_type = "povey";
frame_opts.frame_length_ms = 25;
opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins;
opts.fbank_opts.frame_opts = frame_opts;
} else {
opts.to_float32 = true;
frame_opts.remove_dc_offset = false;
frame_opts.frame_length_ms = 20;
frame_opts.window_type = "hanning";
frame_opts.preemph_coeff = 0.0;
opts.linear_spectrogram_opts.frame_opts = frame_opts;
}
LOG(INFO) << "frame length ms: " << frame_opts.frame_length_ms;
// assembler opts
opts.assembler_opts.subsampling_rate = FLAGS_subsampling_rate;
opts.assembler_opts.receptive_filed_length =
FLAGS_receptive_field_length;
opts.assembler_opts.nnet_decoder_chunk = FLAGS_nnet_decoder_chunk;
opts.assembler_opts.fill_zero = FLAGS_fill_zero;
LOG(INFO) << "subsampling rate: "
<< opts.assembler_opts.subsampling_rate;
LOG(INFO) << "nnet receptive filed length: "
<< opts.assembler_opts.receptive_filed_length;
LOG(INFO) << "nnet chunk size: "
<< opts.assembler_opts.nnet_decoder_chunk;
LOG(INFO) << "frontend fill zeros: " << opts.assembler_opts.fill_zero;
return opts;
}
}; };
class FeaturePipeline : public FrontendInterface { class FeaturePipeline : public FrontendInterface {
public: public:
explicit FeaturePipeline(const FeaturePipelineOptions& opts); explicit FeaturePipeline(const FeaturePipelineOptions& opts);
@ -60,7 +111,21 @@ class FeaturePipeline : public FrontendInterface {
virtual bool IsFinished() const { return base_extractor_->IsFinished(); } virtual bool IsFinished() const { return base_extractor_->IsFinished(); }
virtual void Reset() { base_extractor_->Reset(); } virtual void Reset() { base_extractor_->Reset(); }
const FeaturePipelineOptions& Config() { return opts_; }
const BaseFloat FrameShift() const {
return opts_.fbank_opts.frame_opts.frame_shift_ms;
}
const BaseFloat FrameLength() const {
return opts_.fbank_opts.frame_opts.frame_length_ms;
}
const BaseFloat SampleRate() const {
return opts_.fbank_opts.frame_opts.samp_freq;
}
private: private:
FeaturePipelineOptions opts_;
std::unique_ptr<FrontendInterface> base_extractor_; std::unique_ptr<FrontendInterface> base_extractor_;
}; };
}
} // namespace ppspeech
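
Review note: this hunk moves the defaults of `FeaturePipelineOptions` into default member initializers and adds a static factory that fills the fbank or linear branch from flags. The sketch below shows the same pattern with standard-library types only. `FrameOptions`, `PipelineOptions`, `Init` and `ActiveFrameOptions` are hypothetical names, and a plain `bool` argument stands in for the gflags `FLAGS_*` globals.

```cpp
#include <cassert>
#include <string>

// Hypothetical, trimmed-down stand-in for kaldi::FrameExtractionOptions.
struct FrameOptions {
    float frame_shift_ms{10.0f};
    float frame_length_ms{25.0f};
    std::string window_type{"povey"};
};

struct PipelineOptions {
    // Default member initializers replace the constructor init list.
    bool use_fbank{true};
    bool to_float32{false};
    FrameOptions fbank_frame_opts{};
    FrameOptions linear_frame_opts{};

    // A plain argument replaces the FLAGS_use_fbank global used in the hunk.
    static PipelineOptions Init(bool use_fbank) {
        PipelineOptions opts;
        opts.use_fbank = use_fbank;
        FrameOptions frame_opts;
        if (use_fbank) {
            opts.to_float32 = false;
            frame_opts.window_type = "povey";
            frame_opts.frame_length_ms = 25.0f;
            opts.fbank_frame_opts = frame_opts;
        } else {
            opts.to_float32 = true;
            frame_opts.window_type = "hanning";
            frame_opts.frame_length_ms = 20.0f;
            opts.linear_frame_opts = frame_opts;
        }
        return opts;
    }

    // Reads the frame options of whichever branch was configured.
    const FrameOptions& ActiveFrameOptions() const {
        return use_fbank ? fbank_frame_opts : linear_frame_opts;
    }
};
```

The accessor that selects the active branch is a design choice of the sketch, not something in the diff. The `FrameShift()`, `FrameLength()` and `SampleRate()` accessors added in the hunk read `opts_.fbank_opts` unconditionally, so in the linear configuration they would return whatever `fbank_opts` holds rather than the values written to `linear_spectrogram_opts`. Whether any caller relies on them in that configuration is not visible here.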

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "frontend/audio/linear_spectrogram.h" #include "frontend/audio/linear_spectrogram.h"
#include "kaldi/base/kaldi-math.h" #include "kaldi/base/kaldi-math.h"
#include "kaldi/feat/feature-common.h" #include "kaldi/feat/feature-common.h"
#include "kaldi/feat/feature-functions.h" #include "kaldi/feat/feature-functions.h"
@ -20,12 +21,12 @@
namespace ppspeech { namespace ppspeech {
using kaldi::int32;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Vector; using kaldi::int32;
using kaldi::Matrix;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase; using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector; using std::vector;
LinearSpectrogramComputer::LinearSpectrogramComputer(const Options& opts) LinearSpectrogramComputer::LinearSpectrogramComputer(const Options& opts)

@ -14,6 +14,7 @@
#include "frontend/audio/mfcc.h" #include "frontend/audio/mfcc.h"
#include "kaldi/base/kaldi-math.h" #include "kaldi/base/kaldi-math.h"
#include "kaldi/feat/feature-common.h" #include "kaldi/feat/feature-common.h"
#include "kaldi/feat/feature-functions.h" #include "kaldi/feat/feature-functions.h"
@ -21,12 +22,12 @@
namespace ppspeech { namespace ppspeech {
using kaldi::int32;
using kaldi::BaseFloat; using kaldi::BaseFloat;
using kaldi::Vector; using kaldi::int32;
using kaldi::Matrix;
using kaldi::SubVector; using kaldi::SubVector;
using kaldi::Vector;
using kaldi::VectorBase; using kaldi::VectorBase;
using kaldi::Matrix;
using std::vector; using std::vector;
Mfcc::Mfcc(const MfccOptions& opts, Mfcc::Mfcc(const MfccOptions& opts,

@ -14,7 +14,6 @@
#pragma once #pragma once
#include "kaldi/feat/feature-mfcc.h"
#include "kaldi/feat/feature-mfcc.h" #include "kaldi/feat/feature-mfcc.h"
#include "kaldi/matrix/kaldi-vector.h" #include "kaldi/matrix/kaldi-vector.h"

Some files were not shown because too many files have changed in this diff
