diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index a18c454c..1ff47330 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -27,4 +27,4 @@ git commit -m "xxxxxx, test=doc" 1. 虽然跳过了 CI,但是还要先排队排到才能跳过,所以非自己方向看到 pending 不要着急 🤣 2. 在 `git commit --amend` 的时候才加 `test=xxx` 可能不太有效 3. 一个 pr 多次提交 commit 注意每次都要加 `test=xxx`,因为每个 commit 都会触发 CI -4. 删除 python 环境中已经安装好的的 paddlespeech,否则可能会影响 import paddlespeech 的顺序 +4. 删除 python 环境中已经安装好的 paddlespeech,否则可能会影响 import paddlespeech 的顺序 diff --git a/.mergify.yml b/.mergify.yml index 5cb1f486..0f182b51 100644 --- a/.mergify.yml +++ b/.mergify.yml @@ -136,7 +136,7 @@ pull_request_rules: add: ["Docker"] - name: "auto add label=Deployment" conditions: - - files~=^speechx/ + - files~=^runtime/ actions: label: add: ["Deployment"] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 53fc6ba0..6afa7c9c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,8 +3,12 @@ repos: rev: v0.16.0 hooks: - id: yapf - files: \.py$ - exclude: (?=third_party).*(\.py)$ + name: yapf + language: python + entry: yapf + args: [-i, -vv] + types: [python] + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: a11d9314b22d8f8c7556443875b731ef05965464 @@ -31,7 +35,7 @@ repos: - --ignore=E501,E228,E226,E261,E266,E128,E402,W503 - --builtins=G,request - --jobs=1 - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|third_party).*(\.cpp|\.cc|\.h\.hpp|\.py)$ - repo : https://github.com/Lucas-C/pre-commit-hooks rev: v1.0.1 @@ -53,16 +57,16 @@ repos: entry: bash .pre-commit-hooks/clang-format.hook -i language: system files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party/kaldi-native-fbank/csrc|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ - id: cpplint name: cpplint description: Static code analysis of C/C++ files language: python files: \.(h\+\+|h|hh|hxx|hpp|cuh|c|cc|cpp|cu|c\+\+|cxx|tpp|txx)$ - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|audio/paddleaudio/third_party/kaldi-native-fbank/csrc|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|runtime/engine/common/matrix|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders|runtime/engine/common/utils).*(\.cpp|\.cc|\.h|\.hpp|\.py)$ entry: cpplint --filter=-build,-whitespace,+whitespace/comma,-whitespace/indent - repo: https://github.com/asottile/reorder_python_imports rev: v2.4.0 hooks: - id: reorder-python-imports - exclude: (?=speechx/speechx/kaldi|audio/paddleaudio/src|speechx/patch|speechx/tools/fstbin|speechx/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$ + exclude: (?=runtime/engine/kaldi|audio/paddleaudio/src|runtime/patch|runtime/tools/fstbin|runtime/tools/lmbin|third_party/ctc_decoders).*(\.cpp|\.cc|\.h\.hpp|\.py)$ diff --git a/README.md b/README.md index 702e4187..fdc5981d 100644 --- a/README.md +++ b/README.md 
@@ -179,6 +179,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision ### Recent Update - 👑 2023.04.25: Add [AMP for U2 conformer](https://github.com/PaddlePaddle/PaddleSpeech/pull/3167). +- 🔥 2023.04.06: Add [subtitle file (.srt format) generation example](./demos/streaming_asr_server). - 🔥 2023.03.14: Add SVS(Singing Voice Synthesis) examples with Opencpop dataset, including [DiffSinger](./examples/opencpop/svs1)、[PWGAN](./examples/opencpop/voc1) and [HiFiGAN](./examples/opencpop/voc5), the effect is continuously optimized. - 👑 2023.03.09: Add [Wav2vec2ASR-zh](./examples/aishell/asr3). - 🎉 2023.03.07: Add [TTS ARM Linux C++ Demo (with C++ Chinese Text Frontend)](./demos/TTSArmLinux). @@ -193,7 +194,7 @@ Via the easy-to-use, efficient, flexible and scalable implementation, our vision - 👑 2022.11.18: Add [Whisper CLI and Demos](https://github.com/PaddlePaddle/PaddleSpeech/pull/2640), support multi language recognition and translation. - 🔥 2022.11.18: Add [Wav2vec2 CLI and Demos](./demos/speech_ssl), Support ASR and Feature Extraction. - 🎉 2022.11.17: Add [male voice for TTS](https://github.com/PaddlePaddle/PaddleSpeech/pull/2660). -- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](./speechx/examples/u2pp_ol/wenetspeech). +- 🔥 2022.11.07: Add [U2/U2++ C++ High Performance Streaming ASR Deployment](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/runtime/examples/u2pp_ol/wenetspeech). - 👑 2022.11.01: Add [Adversarial Loss](https://arxiv.org/pdf/1907.04448.pdf) for [Chinese English mixed TTS](./examples/zh_en_tts/tts3). - 🔥 2022.10.26: Add [Prosody Prediction](./examples/other/rhy) for TTS. - 🎉 2022.10.21: Add [SSML](https://github.com/PaddlePaddle/PaddleSpeech/discussions/2538) for TTS Chinese Text Frontend. diff --git a/README_cn.md b/README_cn.md index 46bef1f6..e5e18f0c 100644 --- a/README_cn.md +++ b/README_cn.md @@ -184,6 +184,7 @@ ### 近期更新 - 👑 2023.04.25: 新增 [U2 conformer 的 AMP 训练](https://github.com/PaddlePaddle/PaddleSpeech/pull/3167). +- 👑 2023.04.06: 新增 [srt格式字幕生成功能](./demos/streaming_asr_server)。 - 🔥 2023.03.14: 新增基于 Opencpop 数据集的 SVS (歌唱合成) 示例,包含 [DiffSinger](./examples/opencpop/svs1)、[PWGAN](./examples/opencpop/voc1) 和 [HiFiGAN](./examples/opencpop/voc5),效果持续优化中。 - 👑 2023.03.09: 新增 [Wav2vec2ASR-zh](./examples/aishell/asr3)。 - 🎉 2023.03.07: 新增 [TTS ARM Linux C++ 部署示例 (包含 C++ 中文文本前端模块)](./demos/TTSArmLinux)。 diff --git a/audio/paddleaudio/backends/soundfile_backend.py b/audio/paddleaudio/backends/soundfile_backend.py index ae7b5b52..9195ea09 100644 --- a/audio/paddleaudio/backends/soundfile_backend.py +++ b/audio/paddleaudio/backends/soundfile_backend.py @@ -191,7 +191,7 @@ def soundfile_save(y: np.ndarray, sr: int, file: os.PathLike) -> None: if sr <= 0: raise ParameterError( - f'Sample rate should be larger than 0, recieved sr = {sr}') + f'Sample rate should be larger than 0, received sr = {sr}') if y.dtype not in ['int16', 'int8']: warnings.warn( diff --git a/dataset/aidatatang_200zh/aidatatang_200zh.py b/dataset/aidatatang_200zh/aidatatang_200zh.py index 85f478c2..3b706c49 100644 --- a/dataset/aidatatang_200zh/aidatatang_200zh.py +++ b/dataset/aidatatang_200zh/aidatatang_200zh.py @@ -18,139 +18,7 @@ Manifest file is a json-format file with each line containing the meta data (i.e. audio filepath, transcript and audio duration) of each audio file in the data set. 
""" -import argparse -import codecs -import json -import os -from pathlib import Path - -import soundfile - -from utils.utility import download -from utils.utility import unpack - -DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') - -URL_ROOT = 'http://www.openslr.org/resources/62' -# URL_ROOT = 'https://openslr.magicdatatech.com/resources/62' -DATA_URL = URL_ROOT + '/aidatatang_200zh.tgz' -MD5_DATA = '6e0f4f39cd5f667a7ee53c397c8d0949' - -parser = argparse.ArgumentParser(description=__doc__) -parser.add_argument( - "--target_dir", - default=DATA_HOME + "/aidatatang_200zh", - type=str, - help="Directory to save the dataset. (default: %(default)s)") -parser.add_argument( - "--manifest_prefix", - default="manifest", - type=str, - help="Filepath prefix for output manifests. (default: %(default)s)") -args = parser.parse_args() - - -def create_manifest(data_dir, manifest_path_prefix): - print("Creating manifest %s ..." % manifest_path_prefix) - json_lines = [] - transcript_path = os.path.join(data_dir, 'transcript', - 'aidatatang_200_zh_transcript.txt') - transcript_dict = {} - for line in codecs.open(transcript_path, 'r', 'utf-8'): - line = line.strip() - if line == '': - continue - audio_id, text = line.split(' ', 1) - # remove withespace, charactor text - text = ''.join(text.split()) - transcript_dict[audio_id] = text - - data_types = ['train', 'dev', 'test'] - for dtype in data_types: - del json_lines[:] - total_sec = 0.0 - total_text = 0.0 - total_num = 0 - - audio_dir = os.path.join(data_dir, 'corpus/', dtype) - for subfolder, _, filelist in sorted(os.walk(audio_dir)): - for fname in filelist: - if not fname.endswith('.wav'): - continue - - audio_path = os.path.abspath(os.path.join(subfolder, fname)) - audio_id = os.path.basename(fname)[:-4] - utt2spk = Path(audio_path).parent.name - - audio_data, samplerate = soundfile.read(audio_path) - duration = float(len(audio_data) / samplerate) - text = transcript_dict[audio_id] - json_lines.append( - json.dumps( - { - 'utt': audio_id, - 'utt2spk': str(utt2spk), - 'feat': audio_path, - 'feat_shape': (duration, ), # second - 'text': text, - }, - ensure_ascii=False)) - - total_sec += duration - total_text += len(text) - total_num += 1 - - manifest_path = manifest_path_prefix + '.' + dtype - with codecs.open(manifest_path, 'w', 'utf-8') as fout: - for line in json_lines: - fout.write(line + '\n') - - manifest_dir = os.path.dirname(manifest_path_prefix) - meta_path = os.path.join(manifest_dir, dtype) + '.meta' - with open(meta_path, 'w') as f: - print(f"{dtype}:", file=f) - print(f"{total_num} utts", file=f) - print(f"{total_sec / (60*60)} h", file=f) - print(f"{total_text} text", file=f) - print(f"{total_text / total_sec} text/sec", file=f) - print(f"{total_sec / total_num} sec/utt", file=f) - - -def prepare_dataset(url, md5sum, target_dir, manifest_path, subset): - """Download, unpack and create manifest file.""" - data_dir = os.path.join(target_dir, subset) - if not os.path.exists(data_dir): - filepath = download(url, md5sum, target_dir) - unpack(filepath, target_dir) - # unpack all audio tar files - audio_dir = os.path.join(data_dir, 'corpus') - for subfolder, dirlist, filelist in sorted(os.walk(audio_dir)): - for sub in dirlist: - print(f"unpack dir {sub}...") - for folder, _, filelist in sorted( - os.walk(os.path.join(subfolder, sub))): - for ftar in filelist: - unpack(os.path.join(folder, ftar), folder, True) - else: - print("Skip downloading and unpacking. Data already exists in %s." 
% - target_dir) - - create_manifest(data_dir, manifest_path) - - -def main(): - if args.target_dir.startswith('~'): - args.target_dir = os.path.expanduser(args.target_dir) - - prepare_dataset( - url=DATA_URL, - md5sum=MD5_DATA, - target_dir=args.target_dir, - manifest_path=args.manifest_prefix, - subset='aidatatang_200zh') - - print("Data download and manifest prepare done!") - +from paddlespeech.dataset.aidatatang_200zh import aidatatang_200zh_main if __name__ == '__main__': - main() + aidatatang_200zh_main() diff --git a/dataset/aishell/README.md b/dataset/aishell/README.md deleted file mode 100644 index a7dd0cf3..00000000 --- a/dataset/aishell/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# [Aishell1](http://openslr.elda.org/33/) - -This Open Source Mandarin Speech Corpus, AISHELL-ASR0009-OS1, is 178 hours long. It is a part of AISHELL-ASR0009, of which utterance contains 11 domains, including smart home, autonomous driving, and industrial production. The whole recording was put in quiet indoor environment, using 3 different devices at the same time: high fidelity microphone (44.1kHz, 16-bit,); Android-system mobile phone (16kHz, 16-bit), iOS-system mobile phone (16kHz, 16-bit). Audios in high fidelity were re-sampled to 16kHz to build AISHELL- ASR0009-OS1. 400 speakers from different accent areas in China were invited to participate in the recording. The manual transcription accuracy rate is above 95%, through professional speech annotation and strict quality inspection. The corpus is divided into training, development and testing sets. ( This database is free for academic research, not in the commerce, if without permission. ) diff --git a/dataset/aishell/aishell.py b/dataset/aishell/aishell.py index ec43104d..b3288757 100644 --- a/dataset/aishell/aishell.py +++ b/dataset/aishell/aishell.py @@ -18,143 +18,7 @@ Manifest file is a json-format file with each line containing the meta data (i.e. audio filepath, transcript and audio duration) of each audio file in the data set. """ -import argparse -import codecs -import json -import os -from pathlib import Path - -import soundfile - -from utils.utility import download -from utils.utility import unpack - -DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') - -URL_ROOT = 'http://openslr.elda.org/resources/33' -# URL_ROOT = 'https://openslr.magicdatatech.com/resources/33' -DATA_URL = URL_ROOT + '/data_aishell.tgz' -MD5_DATA = '2f494334227864a8a8fec932999db9d8' -RESOURCE_URL = URL_ROOT + '/resource_aishell.tgz' -MD5_RESOURCE = '957d480a0fcac85fc18e550756f624e5' - -parser = argparse.ArgumentParser(description=__doc__) -parser.add_argument( - "--target_dir", - default=DATA_HOME + "/Aishell", - type=str, - help="Directory to save the dataset. (default: %(default)s)") -parser.add_argument( - "--manifest_prefix", - default="manifest", - type=str, - help="Filepath prefix for output manifests. (default: %(default)s)") -args = parser.parse_args() - - -def create_manifest(data_dir, manifest_path_prefix): - print("Creating manifest %s ..." 
% manifest_path_prefix) - json_lines = [] - transcript_path = os.path.join(data_dir, 'transcript', - 'aishell_transcript_v0.8.txt') - transcript_dict = {} - for line in codecs.open(transcript_path, 'r', 'utf-8'): - line = line.strip() - if line == '': - continue - audio_id, text = line.split(' ', 1) - # remove withespace, charactor text - text = ''.join(text.split()) - transcript_dict[audio_id] = text - - data_types = ['train', 'dev', 'test'] - for dtype in data_types: - del json_lines[:] - total_sec = 0.0 - total_text = 0.0 - total_num = 0 - - audio_dir = os.path.join(data_dir, 'wav', dtype) - for subfolder, _, filelist in sorted(os.walk(audio_dir)): - for fname in filelist: - audio_path = os.path.abspath(os.path.join(subfolder, fname)) - audio_id = os.path.basename(fname)[:-4] - # if no transcription for audio then skipped - if audio_id not in transcript_dict: - continue - - utt2spk = Path(audio_path).parent.name - audio_data, samplerate = soundfile.read(audio_path) - duration = float(len(audio_data) / samplerate) - text = transcript_dict[audio_id] - json_lines.append( - json.dumps( - { - 'utt': audio_id, - 'utt2spk': str(utt2spk), - 'feat': audio_path, - 'feat_shape': (duration, ), # second - 'text': text - }, - ensure_ascii=False)) - - total_sec += duration - total_text += len(text) - total_num += 1 - - manifest_path = manifest_path_prefix + '.' + dtype - with codecs.open(manifest_path, 'w', 'utf-8') as fout: - for line in json_lines: - fout.write(line + '\n') - - manifest_dir = os.path.dirname(manifest_path_prefix) - meta_path = os.path.join(manifest_dir, dtype) + '.meta' - with open(meta_path, 'w') as f: - print(f"{dtype}:", file=f) - print(f"{total_num} utts", file=f) - print(f"{total_sec / (60*60)} h", file=f) - print(f"{total_text} text", file=f) - print(f"{total_text / total_sec} text/sec", file=f) - print(f"{total_sec / total_num} sec/utt", file=f) - - -def prepare_dataset(url, md5sum, target_dir, manifest_path=None): - """Download, unpack and create manifest file.""" - data_dir = os.path.join(target_dir, 'data_aishell') - if not os.path.exists(data_dir): - filepath = download(url, md5sum, target_dir) - unpack(filepath, target_dir) - # unpack all audio tar files - audio_dir = os.path.join(data_dir, 'wav') - for subfolder, _, filelist in sorted(os.walk(audio_dir)): - for ftar in filelist: - unpack(os.path.join(subfolder, ftar), subfolder, True) - else: - print("Skip downloading and unpacking. Data already exists in %s." 
% - target_dir) - - if manifest_path: - create_manifest(data_dir, manifest_path) - - -def main(): - if args.target_dir.startswith('~'): - args.target_dir = os.path.expanduser(args.target_dir) - - prepare_dataset( - url=DATA_URL, - md5sum=MD5_DATA, - target_dir=args.target_dir, - manifest_path=args.manifest_prefix) - - prepare_dataset( - url=RESOURCE_URL, - md5sum=MD5_RESOURCE, - target_dir=args.target_dir, - manifest_path=None) - - print("Data download and manifest prepare done!") - +from paddlespeech.dataset.aishell import aishell_main if __name__ == '__main__': - main() + aishell_main() diff --git a/dataset/librispeech/librispeech.py b/dataset/librispeech/librispeech.py index 2d6f1763..44567b0c 100644 --- a/dataset/librispeech/librispeech.py +++ b/dataset/librispeech/librispeech.py @@ -28,8 +28,8 @@ from multiprocessing.pool import Pool import distutils.util import soundfile -from utils.utility import download -from utils.utility import unpack +from paddlespeech.dataset.download import download +from paddlespeech.dataset.download import unpack URL_ROOT = "http://openslr.elda.org/resources/12" #URL_ROOT = "https://openslr.magicdatatech.com/resources/12" diff --git a/dataset/mini_librispeech/mini_librispeech.py b/dataset/mini_librispeech/mini_librispeech.py index 0eb80bf8..24bd98d8 100644 --- a/dataset/mini_librispeech/mini_librispeech.py +++ b/dataset/mini_librispeech/mini_librispeech.py @@ -27,8 +27,8 @@ from multiprocessing.pool import Pool import soundfile -from utils.utility import download -from utils.utility import unpack +from paddlespeech.dataset.download import download +from paddlespeech.dataset.download import unpack URL_ROOT = "http://openslr.elda.org/resources/31" URL_TRAIN_CLEAN = URL_ROOT + "/train-clean-5.tar.gz" diff --git a/dataset/musan/musan.py b/dataset/musan/musan.py index ae3430b2..85d986e8 100644 --- a/dataset/musan/musan.py +++ b/dataset/musan/musan.py @@ -29,8 +29,8 @@ import os import soundfile -from utils.utility import download -from utils.utility import unpack +from paddlespeech.dataset.download import download +from paddlespeech.dataset.download import unpack DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') diff --git a/dataset/rir_noise/rir_noise.py b/dataset/rir_noise/rir_noise.py index b1d47558..b98dff72 100644 --- a/dataset/rir_noise/rir_noise.py +++ b/dataset/rir_noise/rir_noise.py @@ -29,8 +29,8 @@ import os import soundfile -from utils.utility import download -from utils.utility import unzip +from paddlespeech.dataset.download import download +from paddlespeech.dataset.download import unzip DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') diff --git a/dataset/thchs30/thchs30.py b/dataset/thchs30/thchs30.py index d41c0e17..c5c3eb7a 100644 --- a/dataset/thchs30/thchs30.py +++ b/dataset/thchs30/thchs30.py @@ -27,8 +27,8 @@ from pathlib import Path import soundfile -from utils.utility import download -from utils.utility import unpack +from paddlespeech.dataset.download import download +from paddlespeech.dataset.download import unpack DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') diff --git a/dataset/timit/timit.py b/dataset/timit/timit.py index c4a9f066..f3889d17 100644 --- a/dataset/timit/timit.py +++ b/dataset/timit/timit.py @@ -28,7 +28,7 @@ from pathlib import Path import soundfile -from utils.utility import unzip +from paddlespeech.dataset.download import unzip URL_ROOT = "" MD5_DATA = "45c68037c7fdfe063a43c851f181fb2d" diff --git a/dataset/voxceleb/voxceleb1.py b/dataset/voxceleb/voxceleb1.py index 
95827f70..8d410067 100644 --- a/dataset/voxceleb/voxceleb1.py +++ b/dataset/voxceleb/voxceleb1.py @@ -31,9 +31,9 @@ from pathlib import Path import soundfile -from utils.utility import check_md5sum -from utils.utility import download -from utils.utility import unzip +from paddlespeech.dataset.download import check_md5sum +from paddlespeech.dataset.download import download +from paddlespeech.dataset.download import unzip # all the data will be download in the current data/voxceleb directory default DATA_HOME = os.path.expanduser('.') diff --git a/dataset/voxceleb/voxceleb2.py b/dataset/voxceleb/voxceleb2.py index fe9e8b9c..6df6d1f3 100644 --- a/dataset/voxceleb/voxceleb2.py +++ b/dataset/voxceleb/voxceleb2.py @@ -27,9 +27,9 @@ from pathlib import Path import soundfile -from utils.utility import check_md5sum -from utils.utility import download -from utils.utility import unzip +from paddlespeech.dataset.download import check_md5sum +from paddlespeech.dataset.download import download +from paddlespeech.dataset.download import unzip # all the data will be download in the current data/voxceleb directory default DATA_HOME = os.path.expanduser('.') diff --git a/dataset/voxforge/voxforge.py b/dataset/voxforge/voxforge.py index 373791bf..327d200b 100644 --- a/dataset/voxforge/voxforge.py +++ b/dataset/voxforge/voxforge.py @@ -28,9 +28,9 @@ import subprocess import soundfile -from utils.utility import download_multi -from utils.utility import getfile_insensitive -from utils.utility import unpack +from paddlespeech.dataset.download import download_multi +from paddlespeech.dataset.download import getfile_insensitive +from paddlespeech.dataset.download import unpack DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') diff --git a/demos/TTSAndroid/README.md b/demos/TTSAndroid/README.md index 36ff969f..36848cbe 100644 --- a/demos/TTSAndroid/README.md +++ b/demos/TTSAndroid/README.md @@ -1,6 +1,6 @@ # 语音合成 Java API Demo 使用指南 -在 Android 上实现语音合成功能,此 Demo 有很好的的易用性和开放性,如在 Demo 中跑自己训练好的模型等。 +在 Android 上实现语音合成功能,此 Demo 有很好的易用性和开放性,如在 Demo 中跑自己训练好的模型等。 本文主要介绍语音合成 Demo 运行方法。 diff --git a/demos/TTSArmLinux/front.conf b/demos/TTSArmLinux/front.conf index 04bd2d97..5960b32a 100644 --- a/demos/TTSArmLinux/front.conf +++ b/demos/TTSArmLinux/front.conf @@ -6,13 +6,13 @@ --jieba_stop_word_path=./dict/jieba/stop_words.utf8 # dict conf fastspeech2_0.4 ---seperate_tone=false +--separate_tone=false --word2phone_path=./dict/fastspeech2_nosil_baker_ckpt_0.4/word2phone_fs2.dict --phone2id_path=./dict/fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt --tone2id_path=./dict/fastspeech2_nosil_baker_ckpt_0.4/word2phone_fs2.dict # dict conf speedyspeech_0.5 -#--seperate_tone=true +#--separate_tone=true #--word2phone_path=./dict/speedyspeech_nosil_baker_ckpt_0.5/word2phone.dict #--phone2id_path=./dict/speedyspeech_nosil_baker_ckpt_0.5/phone_id_map.txt #--tone2id_path=./dict/speedyspeech_nosil_baker_ckpt_0.5/tone_id_map.txt diff --git a/demos/TTSCppFrontend/front_demo/front.conf b/demos/TTSCppFrontend/front_demo/front.conf index e9ce1c94..abff4447 100644 --- a/demos/TTSCppFrontend/front_demo/front.conf +++ b/demos/TTSCppFrontend/front_demo/front.conf @@ -6,13 +6,13 @@ --jieba_stop_word_path=./front_demo/dict/jieba/stop_words.utf8 # dict conf fastspeech2_0.4 ---seperate_tone=false +--separate_tone=false --word2phone_path=./front_demo/dict/fastspeech2_nosil_baker_ckpt_0.4/word2phone_fs2.dict --phone2id_path=./front_demo/dict/fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt 
--tone2id_path=./front_demo/dict/fastspeech2_nosil_baker_ckpt_0.4/word2phone_fs2.dict # dict conf speedyspeech_0.5 -#--seperate_tone=true +#--separate_tone=true #--word2phone_path=./front_demo/dict/speedyspeech_nosil_baker_ckpt_0.5/word2phone.dict #--phone2id_path=./front_demo/dict/speedyspeech_nosil_baker_ckpt_0.5/phone_id_map.txt #--tone2id_path=./front_demo/dict/speedyspeech_nosil_baker_ckpt_0.5/tone_id_map.txt diff --git a/demos/TTSCppFrontend/front_demo/front_demo.cpp b/demos/TTSCppFrontend/front_demo/front_demo.cpp index 19f16758..77f3fc72 100644 --- a/demos/TTSCppFrontend/front_demo/front_demo.cpp +++ b/demos/TTSCppFrontend/front_demo/front_demo.cpp @@ -20,7 +20,7 @@ DEFINE_string(sentence, "你好,欢迎使用语音合成服务", "Text to be synthesized"); DEFINE_string(front_conf, "./front_demo/front.conf", "Front conf file"); -// DEFINE_string(seperate_tone, "true", "If true, get phoneids and tonesid"); +// DEFINE_string(separate_tone, "true", "If true, get phoneids and tonesid"); int main(int argc, char** argv) { diff --git a/demos/TTSCppFrontend/front_demo/gentools/word2phones.py b/demos/TTSCppFrontend/front_demo/gentools/word2phones.py index 8726ee89..d9baeea9 100644 --- a/demos/TTSCppFrontend/front_demo/gentools/word2phones.py +++ b/demos/TTSCppFrontend/front_demo/gentools/word2phones.py @@ -20,7 +20,7 @@ worddict = "./dict/jieba_part.dict.utf8" newdict = "./dict/word_phones.dict" -def GenPhones(initials, finals, seperate=True): +def GenPhones(initials, finals, separate=True): phones = [] for c, v in zip(initials, finals): @@ -30,9 +30,9 @@ def GenPhones(initials, finals, seperate=True): elif c in ['zh', 'ch', 'sh', 'r']: v = re.sub('i', 'iii', v) if c: - if seperate is True: + if separate is True: phones.append(c + '0') - elif seperate is False: + elif separate is False: phones.append(c) else: print("Not sure whether phone and tone need to be separated") diff --git a/demos/TTSCppFrontend/src/front/front_interface.cpp b/demos/TTSCppFrontend/src/front/front_interface.cpp index 8bd466d2..e7b08c79 100644 --- a/demos/TTSCppFrontend/src/front/front_interface.cpp +++ b/demos/TTSCppFrontend/src/front/front_interface.cpp @@ -126,7 +126,7 @@ int FrontEngineInterface::init() { } // 生成音调字典(音调到音调id的映射) - if (_seperate_tone == "true") { + if (_separate_tone == "true") { if (0 != GenDict(_tone2id_path, &tone_id_map)) { LOG(ERROR) << "Genarate tone2id dict failed"; return -1; @@ -168,7 +168,7 @@ int FrontEngineInterface::ReadConfFile() { _jieba_stop_word_path = conf_map["jieba_stop_word_path"]; // dict path - _seperate_tone = conf_map["seperate_tone"]; + _separate_tone = conf_map["separate_tone"]; _word2phone_path = conf_map["word2phone_path"]; _phone2id_path = conf_map["phone2id_path"]; _tone2id_path = conf_map["tone2id_path"]; @@ -295,7 +295,7 @@ int FrontEngineInterface::GetWordsIds( } } } else { // 标点符号 - if (_seperate_tone == "true") { + if (_separate_tone == "true") { phone = "sp0"; // speedyspeech } else { phone = "sp"; // fastspeech2 @@ -354,7 +354,7 @@ int FrontEngineInterface::Phone2Phoneid(const std::string &phone, std::string temp_phone; for (int i = 0; i < phone_vec.size(); i++) { temp_phone = phone_vec[i]; - if (_seperate_tone == "true") { + if (_separate_tone == "true") { phoneid->push_back(atoi( (phone_id_map[temp_phone.substr(0, temp_phone.length() - 1)]) .c_str())); diff --git a/demos/TTSCppFrontend/src/front/front_interface.h b/demos/TTSCppFrontend/src/front/front_interface.h index fc33a4de..8c16859c 100644 --- a/demos/TTSCppFrontend/src/front/front_interface.h +++ 
b/demos/TTSCppFrontend/src/front/front_interface.h @@ -182,7 +182,7 @@ class FrontEngineInterface : public TextNormalizer { std::string _jieba_idf_path; std::string _jieba_stop_word_path; - std::string _seperate_tone; + std::string _separate_tone; std::string _word2phone_path; std::string _phone2id_path; std::string _tone2id_path; diff --git a/demos/audio_searching/src/test_audio_search.py b/demos/audio_searching/src/test_audio_search.py index cb91e156..f9ea2929 100644 --- a/demos/audio_searching/src/test_audio_search.py +++ b/demos/audio_searching/src/test_audio_search.py @@ -14,8 +14,8 @@ from audio_search import app from fastapi.testclient import TestClient -from utils.utility import download -from utils.utility import unpack +from paddlespeech.dataset.download import download +from paddlespeech.dataset.download import unpack client = TestClient(app) diff --git a/demos/audio_searching/src/test_vpr_search.py b/demos/audio_searching/src/test_vpr_search.py index 298e12eb..cc795564 100644 --- a/demos/audio_searching/src/test_vpr_search.py +++ b/demos/audio_searching/src/test_vpr_search.py @@ -14,8 +14,8 @@ from fastapi.testclient import TestClient from vpr_search import app -from utils.utility import download -from utils.utility import unpack +from paddlespeech.dataset.download import download +from paddlespeech.dataset.download import unpack client = TestClient(app) diff --git a/demos/speech_web/README.md b/demos/speech_web/README.md index 572781ab..fc1fe710 100644 --- a/demos/speech_web/README.md +++ b/demos/speech_web/README.md @@ -23,7 +23,7 @@ Paddle Speech Demo 是一个以 PaddleSpeech 的语音交互功能为主体开 + ERNIE-SAT:语言-语音跨模态大模型 ERNIE-SAT 可视化展示示例,支持个性化合成,跨语言语音合成(音频为中文则输入英文文本进行合成),语音编辑(修改音频文字中间的结果)功能。 ERNIE-SAT 更多实现细节,可以参考: + [【ERNIE-SAT with AISHELL-3 dataset】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/ernie_sat) - + [【ERNIE-SAT with with AISHELL3 and VCTK datasets】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3_vctk/ernie_sat) + + [【ERNIE-SAT with AISHELL3 and VCTK datasets】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3_vctk/ernie_sat) + [【ERNIE-SAT with VCTK dataset】](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/vctk/ernie_sat) 运行效果: diff --git a/demos/speech_web/speech_server/main.py b/demos/speech_web/speech_server/main.py index 03e7e599..f4678628 100644 --- a/demos/speech_web/speech_server/main.py +++ b/demos/speech_web/speech_server/main.py @@ -260,7 +260,7 @@ async def websocket_endpoint_online(websocket: WebSocket): # and we break the loop if message['signal'] == 'start': resp = {"status": "ok", "signal": "server_ready"} - # do something at begining here + # do something at beginning here # create the instance to process the audio # connection_handler = chatbot.asr.connection_handler connection_handler = PaddleASRConnectionHanddler(engine) diff --git a/demos/streaming_asr_server/README.md b/demos/streaming_asr_server/README.md index c15d0601..31256d15 100644 --- a/demos/streaming_asr_server/README.md +++ b/demos/streaming_asr_server/README.md @@ -579,3 +579,354 @@ bash server.sh [2022-05-07 11:11:18,915] [ INFO] - audio duration: 4.9968125, elapsed time: 15.928460597991943, RTF=3.187724293835709 [2022-05-07 11:11:18,916] [ INFO] - asr websocket client finished : 我认为跑步最重要的就是给我带来了身体健康 ``` + +## Generate the corresponding subtitle (.srt format) from an audio file (.wav or .mp3 format) + +By default, each server is deployed on the CPU, and speech recognition and punctuation
prediction can each be deployed on a different GPU by modifying the `device` parameter in the corresponding service configuration file. + +We use the two services `streaming_asr_server.py` and `punc_server.py` to launch streaming speech recognition and punctuation prediction respectively. The `websocket_client_srt.py` script can be used to call both services at the same time and will generate the corresponding subtitle (.srt format). + +**You need to install ffmpeg before running this script.** + +**You should be in the `.../demos/streaming_asr_server/` directory.** + +### 1. Start the two servers + +```bash +# Note: streaming speech recognition and punctuation prediction are configured on different GPUs through their configuration files +paddlespeech_server start --config_file ./conf/ws_conformer_wenetspeech_application.yaml +``` + +Open another terminal and run the following command: +```bash +paddlespeech_server start --config_file conf/punc_application.yaml +``` + +### 2. Call client + + ```bash + python3 local/websocket_client_srt.py --server_ip 127.0.0.1 --port 8090 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ../../data/认知.mp3 + ``` + Output: + ```text + [2023-03-30 23:26:13,991] [ INFO] - Start to do streaming asr client +[2023-03-30 23:26:13,994] [ INFO] - asr websocket client start +[2023-03-30 23:26:13,994] [ INFO] - endpoint: http://127.0.0.1:8190/paddlespeech/text +[2023-03-30 23:26:13,994] [ INFO] - endpoint: ws://127.0.0.1:8090/paddlespeech/asr/streaming +[2023-03-30 23:26:14,475] [ INFO] - /home/fxb/PaddleSpeech-develop/data/认知.mp3 converted to /home/fxb/PaddleSpeech-develop/data/认知.wav +[2023-03-30 23:26:14,476] [ INFO] - start to process the wavscp: /home/fxb/PaddleSpeech-develop/data/认知.wav +[2023-03-30 23:26:14,515] [ INFO] - client receive msg={"status": "ok", "signal": "server_ready"} +[2023-03-30 23:26:14,533] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,545] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,556] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,572] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,588] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,600] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,613] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,626] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:15,122] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,135] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,154] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,163] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,175] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,185] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,196] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,637] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,648] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,657] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,666] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,676] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,683] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,691] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,703] [ INFO] -
client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:16,146] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,159] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,167] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,177] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,187] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,197] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,210] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,694] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,704] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,713] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,725] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,737] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,749] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,759] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,770] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:17,279] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,302] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,316] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,332] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,343] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,358] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,373] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,958] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:17,971] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:17,987] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,000] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,017] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,028] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,038] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,049] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,653] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,689] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,701] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,712] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,723] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,750] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,767] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:19,295] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,307] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,323] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,332] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} 
+[2023-03-30 23:26:19,342] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,349] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,373] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,389] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:20,046] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,055] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,067] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,076] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,094] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,124] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,135] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,732] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,742] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,757] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,770] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,782] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,798] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,815] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,834] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:21,390] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,405] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,416] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,428] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,448] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,459] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,473] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:22,065] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,085] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,110] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,118] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,137] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,144] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,154] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,169] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,698] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 
23:26:22,709] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,731] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,743] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,755] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,771] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,782] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:23,415] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,430] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,442] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,456] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,470] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,487] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,498] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,524] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:24,200] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,210] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,219] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,231] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,250] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,262] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,272] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,898] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,903] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,907] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,932] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,957] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,979] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,991] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:25,011] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:25,616] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,625] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,648] 
[ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,658] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,669] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,681] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,690] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,707] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,378] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,384] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,389] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,397] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,402] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,415] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,428] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:27,008] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,018] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,026] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,037] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,046] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,054] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,062] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,070] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,735] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,745] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,755] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,769] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,783] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,794] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,804] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:28,454] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,472] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,481] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,489] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,499] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,533] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,543] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,556] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:29,212] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,222] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,233] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,246] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,258] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,270] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,286] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:30,003] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,013] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,038] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,048] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,062] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,074] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,114] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,125] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,856] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,876] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,885] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,897] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,914] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,940] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,952] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:31,655] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,696] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,709] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,718] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,727] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,740] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,757] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,768] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:32,476] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,486] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,495] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,549] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,560] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,574] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,590] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:33,338] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,356] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,368] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,386] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,397] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,409] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,424] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,434] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:34,352] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,364] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,377] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,395] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,410] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,423] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,434] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:35,373] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,397] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,410] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,420] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,437] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,448] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,460] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,473] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:36,288] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,297] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,306] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,326] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,336] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,351] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,365] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:37,164] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,173] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,182] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,192] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,204] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,232] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,238] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,252] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:38,084] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,093] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,106] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,122] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,140] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,181] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,206] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:39,094] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,111] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,132] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,150] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,174] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,190] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,197] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,212] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:40,009] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,094] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,105] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,128] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,149] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,173] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,189] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,200] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,952] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:40,973] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:40,986] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:40,999] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:41,013] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:41,022] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:41,033] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:41,819] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,832] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,845] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,878] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,886] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,893] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,925] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,935] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:42,562] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,589] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,621] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,634] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,644] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,657] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,668] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:43,380] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,389] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,436] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,448] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,462] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,472] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,486] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,496] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:44,346] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,356] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,364] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,374] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,389] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,398] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,420] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:45,226] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:45,235] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:45,258] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:45,273] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:45,295] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:45,306] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:46,380] [ INFO] - client punctuation restored msg={'result': '第一部分是认知部分,该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理,让学生对设备有大致的认知。随后使用真实传感器的内部构造图,辅以文字说明,进一步帮助学生对传感器有更深刻的印象,最后结合具体的实践应用,提升学生对实训的兴趣以及意义感。'} +[2023-03-30 23:27:01,059] [ INFO] - client final receive msg={'status': 'ok', 'signal': 'finished', 'result': '第一部分是认知部分,该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理,让学生对设备有大致的认知。随后使用真实传感器的内部构造图,辅以文字说明,进一步帮助学生对传感器有更深刻的印象,最后结合具体的实践应用,提升学生对实训的兴趣以及意义感。', 'times': [{'w': '第', 'bg': 0.0, 'ed': 0.36}, {'w': '一', 'bg': 0.36, 'ed': 0.48}, {'w': '部', 'bg': 0.48, 'ed': 0.62}, {'w': '分', 'bg': 0.62, 'ed': 0.8200000000000001}, {'w': '是', 'bg': 0.8200000000000001, 'ed': 1.08}, {'w': '认', 'bg': 1.08, 'ed': 1.28}, {'w': '知', 'bg': 1.28, 'ed': 1.44}, {'w': '部', 'bg': 1.44, 'ed': 1.58}, {'w': '分', 'bg': 1.58, 'ed': 2.1}, {'w': '该', 'bg': 2.1, 'ed': 2.6}, {'w': '部', 'bg': 2.6, 'ed': 2.72}, {'w': '分', 'bg': 2.72, 'ed': 2.94}, {'w': '通', 'bg': 2.94, 'ed': 3.16}, {'w': '过', 'bg': 3.16, 'ed': 3.36}, {'w': '示', 'bg': 3.36, 'ed': 3.54}, {'w': '意', 'bg': 3.54, 'ed': 3.68}, {'w': '图', 'bg': 3.68, 'ed': 3.9}, {'w': '和', 'bg': 3.9, 'ed': 4.14}, {'w': '文', 'bg': 4.14, 'ed': 4.32}, {'w': '本', 'bg': 4.32, 'ed': 4.46}, {'w': '的', 'bg': 4.46, 'ed': 4.58}, 
{'w': '形', 'bg': 4.58, 'ed': 4.72}, {'w': '式', 'bg': 4.72, 'ed': 5.0}, {'w': '向', 'bg': 5.0, 'ed': 5.32}, {'w': '学', 'bg': 5.32, 'ed': 5.5}, {'w': '生', 'bg': 5.5, 'ed': 5.66}, {'w': '讲', 'bg': 5.66, 'ed': 5.86}, {'w': '解', 'bg': 5.86, 'ed': 6.18}, {'w': '主', 'bg': 6.18, 'ed': 6.46}, {'w': '要', 'bg': 6.46, 'ed': 6.62}, {'w': '传', 'bg': 6.62, 'ed': 6.8}, {'w': '感', 'bg': 6.8, 'ed': 7.0}, {'w': '器', 'bg': 7.0, 'ed': 7.16}, {'w': '的', 'bg': 7.16, 'ed': 7.28}, {'w': '工', 'bg': 7.28, 'ed': 7.44}, {'w': '作', 'bg': 7.44, 'ed': 7.6000000000000005}, {'w': '原', 'bg': 7.6000000000000005, 'ed': 7.74}, {'w': '理', 'bg': 7.74, 'ed': 8.06}, {'w': '让', 'bg': 8.06, 'ed': 8.44}, {'w': '学', 'bg': 8.44, 'ed': 8.64}, {'w': '生', 'bg': 8.64, 'ed': 8.84}, {'w': '对', 'bg': 8.84, 'ed': 9.06}, {'w': '设', 'bg': 9.06, 'ed': 9.24}, {'w': '备', 'bg': 9.24, 'ed': 9.52}, {'w': '有', 'bg': 9.52, 'ed': 9.86}, {'w': '大', 'bg': 9.86, 'ed': 10.1}, {'w': '致', 'bg': 10.1, 'ed': 10.24}, {'w': '的', 'bg': 10.24, 'ed': 10.36}, {'w': '认', 'bg': 10.36, 'ed': 10.5}, {'w': '知', 'bg': 10.5, 'ed': 11.040000000000001}, {'w': '随', 'bg': 11.040000000000001, 'ed': 11.56}, {'w': '后', 'bg': 11.56, 'ed': 11.82}, {'w': '使', 'bg': 11.82, 'ed': 12.1}, {'w': '用', 'bg': 12.1, 'ed': 12.26}, {'w': '真', 'bg': 12.26, 'ed': 12.44}, {'w': '实', 'bg': 12.44, 'ed': 12.620000000000001}, {'w': '传', 'bg': 12.620000000000001, 'ed': 12.780000000000001}, {'w': '感', 'bg': 12.780000000000001, 'ed': 12.94}, {'w': '器', 'bg': 12.94, 'ed': 13.1}, {'w': '的', 'bg': 13.1, 'ed': 13.26}, {'w': '内', 'bg': 13.26, 'ed': 13.42}, {'w': '部', 'bg': 13.42, 'ed': 13.56}, {'w': '构', 'bg': 13.56, 'ed': 13.700000000000001}, {'w': '造', 'bg': 13.700000000000001, 'ed': 13.86}, {'w': '图', 'bg': 13.86, 'ed': 14.280000000000001}, {'w': '辅', 'bg': 14.280000000000001, 'ed': 14.66}, {'w': '以', 'bg': 14.66, 'ed': 14.82}, {'w': '文', 'bg': 14.82, 'ed': 15.0}, {'w': '字', 'bg': 15.0, 'ed': 15.16}, {'w': '说', 'bg': 15.16, 'ed': 15.32}, {'w': '明', 'bg': 15.32, 'ed': 15.72}, {'w': '进', 'bg': 15.72, 'ed': 16.1}, {'w': '一', 'bg': 16.1, 'ed': 16.2}, {'w': '步', 'bg': 16.2, 'ed': 16.32}, {'w': '帮', 'bg': 16.32, 'ed': 16.48}, {'w': '助', 'bg': 16.48, 'ed': 16.66}, {'w': '学', 'bg': 16.66, 'ed': 16.82}, {'w': '生', 'bg': 16.82, 'ed': 17.12}, {'w': '对', 'bg': 17.12, 'ed': 17.48}, {'w': '传', 'bg': 17.48, 'ed': 17.66}, {'w': '感', 'bg': 17.66, 'ed': 17.84}, {'w': '器', 'bg': 17.84, 'ed': 18.12}, {'w': '有', 'bg': 18.12, 'ed': 18.42}, {'w': '更', 'bg': 18.42, 'ed': 18.66}, {'w': '深', 'bg': 18.66, 'ed': 18.88}, {'w': '刻', 'bg': 18.88, 'ed': 19.04}, {'w': '的', 'bg': 19.04, 'ed': 19.16}, {'w': '印', 'bg': 19.16, 'ed': 19.3}, {'w': '象', 'bg': 19.3, 'ed': 19.8}, {'w': '最', 'bg': 19.8, 'ed': 20.3}, {'w': '后', 'bg': 20.3, 'ed': 20.62}, {'w': '结', 'bg': 20.62, 'ed': 20.96}, {'w': '合', 'bg': 20.96, 'ed': 21.14}, {'w': '具', 'bg': 21.14, 'ed': 21.3}, {'w': '体', 'bg': 21.3, 'ed': 21.42}, {'w': '的', 'bg': 21.42, 'ed': 21.580000000000002}, {'w': '实', 'bg': 21.580000000000002, 'ed': 21.76}, {'w': '践', 'bg': 21.76, 'ed': 21.92}, {'w': '应', 'bg': 21.92, 'ed': 22.080000000000002}, {'w': '用', 'bg': 22.080000000000002, 'ed': 22.44}, {'w': '提', 'bg': 22.44, 'ed': 22.78}, {'w': '升', 'bg': 22.78, 'ed': 22.94}, {'w': '学', 'bg': 22.94, 'ed': 23.12}, {'w': '生', 'bg': 23.12, 'ed': 23.34}, {'w': '对', 'bg': 23.34, 'ed': 23.62}, {'w': '实', 'bg': 23.62, 'ed': 23.82}, {'w': '训', 'bg': 23.82, 'ed': 23.96}, {'w': '的', 'bg': 23.96, 'ed': 24.12}, {'w': '兴', 'bg': 24.12, 'ed': 24.3}, {'w': '趣', 'bg': 24.3, 'ed': 24.6}, {'w': '以', 'bg': 24.6, 'ed': 24.88}, {'w': 
'及', 'bg': 24.88, 'ed': 25.12}, {'w': '意', 'bg': 25.12, 'ed': 25.34}, {'w': '义', 'bg': 25.34, 'ed': 25.46}, {'w': '感', 'bg': 25.46, 'ed': 26.04}]}
+[2023-03-30 23:27:01,060] [ INFO] - audio duration: 26.04, elapsed time: 46.581613540649414, RTF=1.7888484462614982
+sentences: ['第一部分是认知部分', '该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理', '让学生对设备有大致的认知', '随后使用真实传感器的内部构造图', '辅以文字说明', '进一步帮助学生对传感器有更深刻的印象', '最后结合具体的实践应用', '提升学生对实训的兴趣以及意义感']
+relative_times: [[0.0, 2.1], [2.1, 8.06], [8.06, 11.040000000000001], [11.040000000000001, 14.280000000000001], [14.280000000000001, 15.72], [15.72, 19.8], [19.8, 22.44], [22.44, 26.04]]
+[2023-03-30 23:27:01,076] [ INFO] - results saved to /home/fxb/PaddleSpeech-develop/data/认知.srt
+ ```
diff --git a/demos/streaming_asr_server/README_cn.md b/demos/streaming_asr_server/README_cn.md
index 26a6ce40..bbddd693 100644
--- a/demos/streaming_asr_server/README_cn.md
+++ b/demos/streaming_asr_server/README_cn.md
@@ -578,3 +578,354 @@ bash server.sh
 [2022-05-07 11:11:18,915] [ INFO] - audio duration: 4.9968125, elapsed time: 15.928460597991943, RTF=3.187724293835709
 [2022-05-07 11:11:18,916] [ INFO] - asr websocket client finished : 我认为跑步最重要的就是给我带来了身体健康
 ```
+
+## 从音频文件(.wav 格式或者 .mp3 格式)生成字幕文件(.srt 格式)
+
+**注意:** 默认部署在 `cpu` 设备上,可以通过修改服务配置文件中的 `device` 参数,将语音识别和标点预测部署在不同的 `gpu` 上。
+
+使用 `streaming_asr_server.py` 和 `punc_server.py` 两个脚本,分别启动流式语音识别和标点预测服务。随后调用 `websocket_client_srt.py` 脚本,即可同时请求这两个服务,并生成对应的字幕文件(.srt 格式)。
+
+**使用该脚本前需要安装 ffmpeg。**
+
+**应该在对应的 `.../demos/streaming_asr_server/` 目录下运行以下脚本。**
+
+### 1. 启动服务端
+
+```bash
+# Note: streaming speech recognition and punctuation prediction can be deployed on different GPUs via their configuration files
+paddlespeech_server start --config_file ./conf/ws_conformer_wenetspeech_application.yaml
+```
+
+Open another terminal and run the following command:
+```bash
+paddlespeech_server start --config_file conf/punc_application.yaml
+```
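+
+Before running the client below, note that an `.mp3` input is first converted to `.wav` (this is what ffmpeg is needed for, as the log below shows with "认知.mp3 converted to 认知.wav"), roughly equivalent to `ffmpeg -i ../../data/认知.mp3 -ac 1 -ar 16000 ../../data/认知.wav`; the 16 kHz mono parameters are an assumption here, and the authoritative conversion logic lives in `local/websocket_client_srt.py`.
+
+Once recognition finishes, the client groups the word-level timestamps (`times` in the final message) into sentences and renders the resulting `sentences` / `relative_times` pairs as `.srt` entries. The following is a minimal sketch of that rendering step, assuming the list formats printed in the log below; `sec_to_srt` and `to_srt` are hypothetical helper names, and the actual implementation in `local/websocket_client_srt.py` may differ:
+
+```python
+# Hypothetical sketch: render (sentence, [begin, end]) pairs (in seconds) as SRT blocks.
+def sec_to_srt(t: float) -> str:
+    ms = int(round(t * 1000))     # total milliseconds
+    h, rem = divmod(ms, 3600000)  # hours
+    m, rem = divmod(rem, 60000)   # minutes
+    s, ms = divmod(rem, 1000)     # seconds and milliseconds
+    return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
+
+
+def to_srt(sentences, relative_times):
+    blocks = []
+    for i, (text, (bg, ed)) in enumerate(zip(sentences, relative_times), 1):
+        blocks.append(f"{i}\n{sec_to_srt(bg)} --> {sec_to_srt(ed)}\n{text}\n")
+    return "\n".join(blocks)
+```
+
+For example, the first sentence in the log below ([0.0, 2.1] s) would become the block `1` / `00:00:00,000 --> 00:00:02,100` / `第一部分是认知部分`.
+
+### 2. 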
启动客户端 + + ```bash + python3 local/websocket_client_srt.py --server_ip 127.0.0.1 --port 8090 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ../../data/认知.mp3 + ``` + Output: + ```text + [2023-03-30 23:26:13,991] [ INFO] - Start to do streaming asr client +[2023-03-30 23:26:13,994] [ INFO] - asr websocket client start +[2023-03-30 23:26:13,994] [ INFO] - endpoint: http://127.0.0.1:8190/paddlespeech/text +[2023-03-30 23:26:13,994] [ INFO] - endpoint: ws://127.0.0.1:8090/paddlespeech/asr/streaming +[2023-03-30 23:26:14,475] [ INFO] - /home/fxb/PaddleSpeech-develop/data/认知.mp3 converted to /home/fxb/PaddleSpeech-develop/data/认知.wav +[2023-03-30 23:26:14,476] [ INFO] - start to process the wavscp: /home/fxb/PaddleSpeech-develop/data/认知.wav +[2023-03-30 23:26:14,515] [ INFO] - client receive msg={"status": "ok", "signal": "server_ready"} +[2023-03-30 23:26:14,533] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,545] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,556] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,572] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,588] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,600] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,613] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:14,626] [ INFO] - client receive msg={'result': ''} +[2023-03-30 23:26:15,122] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,135] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,154] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,163] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,175] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,185] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,196] [ INFO] - client receive msg={'result': '第一部'} +[2023-03-30 23:26:15,637] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,648] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,657] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,666] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,676] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,683] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,691] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:15,703] [ INFO] - client receive msg={'result': '第一部分是认'} +[2023-03-30 23:26:16,146] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,159] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,167] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,177] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,187] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,197] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,210] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,694] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,704] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,713] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,725] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,737] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,749] [ INFO] - 
client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,759] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:16,770] [ INFO] - client receive msg={'result': '第一部分是认知部分'} +[2023-03-30 23:26:17,279] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,302] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,316] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,332] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,343] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,358] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,373] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通'} +[2023-03-30 23:26:17,958] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:17,971] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:17,987] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,000] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,017] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,028] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,038] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,049] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图'} +[2023-03-30 23:26:18,653] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,689] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,701] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,712] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,723] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,750] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:18,767] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本'} +[2023-03-30 23:26:19,295] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,307] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,323] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,332] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,342] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,349] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,373] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:19,389] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式'} +[2023-03-30 23:26:20,046] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,055] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,067] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,076] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,094] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,124] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,135] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生'} +[2023-03-30 23:26:20,732] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,742] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,757] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,770] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,782] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,798] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,815] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:20,834] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解'} +[2023-03-30 23:26:21,390] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,405] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,416] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,428] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,448] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,459] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:21,473] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感'} +[2023-03-30 23:26:22,065] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,085] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,110] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,118] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,137] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,144] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,154] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,169] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作'} +[2023-03-30 23:26:22,698] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,709] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,731] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,743] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,755] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,771] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:22,782] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理'} +[2023-03-30 23:26:23,415] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,430] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,442] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,456] [ INFO] - client receive 
msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,470] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,487] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,498] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:23,524] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生'} +[2023-03-30 23:26:24,200] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,210] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,219] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,231] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,250] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,262] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,272] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备'} +[2023-03-30 23:26:24,898] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,903] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,907] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,932] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,957] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,979] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:24,991] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:25,011] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致'} +[2023-03-30 23:26:25,616] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,625] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,648] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,658] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,669] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,681] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,690] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:25,707] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,378] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,384] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,389] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,397] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,402] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,415] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:26,428] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知'} +[2023-03-30 23:26:27,008] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,018] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,026] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,037] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,046] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,054] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,062] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,070] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使'} +[2023-03-30 23:26:27,735] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,745] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,755] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,769] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,783] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,794] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:27,804] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传'} +[2023-03-30 23:26:28,454] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,472] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,481] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,489] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,499] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,533] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,543] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:28,556] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内'} +[2023-03-30 23:26:29,212] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,222] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,233] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,246] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,258] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,270] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:29,286] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图'} +[2023-03-30 23:26:30,003] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,013] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,038] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,048] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,062] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,074] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,114] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,125] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅'} +[2023-03-30 23:26:30,856] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,876] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,885] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,897] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,914] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,940] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:30,952] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说'} +[2023-03-30 23:26:31,655] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,696] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,709] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,718] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 
23:26:31,727] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,740] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,757] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:31,768] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明'} +[2023-03-30 23:26:32,476] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,486] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,495] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,549] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,560] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,574] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:32,590] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助'} +[2023-03-30 23:26:33,338] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,356] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,368] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,386] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,397] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,409] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,424] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:33,434] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生'} +[2023-03-30 23:26:34,352] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,364] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,377] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,395] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,410] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 
23:26:34,423] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:34,434] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感'} +[2023-03-30 23:26:35,373] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,397] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,410] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,420] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,437] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,448] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,460] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:35,473] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有'} +[2023-03-30 23:26:36,288] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,297] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,306] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,326] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,336] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,351] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:36,365] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的'} +[2023-03-30 23:26:37,164] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,173] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,182] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,192] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,204] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,232] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,238] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:37,252] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象'} +[2023-03-30 23:26:38,084] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,093] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,106] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,122] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,140] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,181] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:38,206] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后'} +[2023-03-30 23:26:39,094] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,111] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,132] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,150] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,174] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,190] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,197] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:39,212] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合'} +[2023-03-30 23:26:40,009] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,094] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,105] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,128] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 
23:26:40,149] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,173] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,189] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,200] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实'} +[2023-03-30 23:26:40,952] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:40,973] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:40,986] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:40,999] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:41,013] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:41,022] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:41,033] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用'} +[2023-03-30 23:26:41,819] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,832] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,845] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,878] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,886] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,893] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,925] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:41,935] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升'} +[2023-03-30 23:26:42,562] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,589] [ INFO] - client receive msg={'result': 
'第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,621] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,634] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,644] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,657] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:42,668] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对'} +[2023-03-30 23:26:43,380] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,389] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,436] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,448] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,462] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,472] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,486] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:43,496] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴'} +[2023-03-30 23:26:44,346] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,356] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,364] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,374] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,389] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,398] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:44,420] [ 
INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以'} +[2023-03-30 23:26:45,226] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:45,235] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:45,258] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:45,273] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:45,295] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:45,306] [ INFO] - client receive msg={'result': '第一部分是认知部分该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理让学生对设备有大致的认知随后使用真实传感器的内部构造图辅以文字说明进一步帮助学生对传感器有更深刻的印象最后结合具体的实践应用提升学生对实训的兴趣以及意义感'} +[2023-03-30 23:26:46,380] [ INFO] - client punctuation restored msg={'result': '第一部分是认知部分,该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理,让学生对设备有大致的认知。随后使用真实传感器的内部构造图,辅以文字说明,进一步帮助学生对传感器有更深刻的印象,最后结合具体的实践应用,提升学生对实训的兴趣以及意义感。'} +[2023-03-30 23:27:01,059] [ INFO] - client final receive msg={'status': 'ok', 'signal': 'finished', 'result': '第一部分是认知部分,该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理,让学生对设备有大致的认知。随后使用真实传感器的内部构造图,辅以文字说明,进一步帮助学生对传感器有更深刻的印象,最后结合具体的实践应用,提升学生对实训的兴趣以及意义感。', 'times': [{'w': '第', 'bg': 0.0, 'ed': 0.36}, {'w': '一', 'bg': 0.36, 'ed': 0.48}, {'w': '部', 'bg': 0.48, 'ed': 0.62}, {'w': '分', 'bg': 0.62, 'ed': 0.8200000000000001}, {'w': '是', 'bg': 0.8200000000000001, 'ed': 1.08}, {'w': '认', 'bg': 1.08, 'ed': 1.28}, {'w': '知', 'bg': 1.28, 'ed': 1.44}, {'w': '部', 'bg': 1.44, 'ed': 1.58}, {'w': '分', 'bg': 1.58, 'ed': 2.1}, {'w': '该', 'bg': 2.1, 'ed': 2.6}, {'w': '部', 'bg': 2.6, 'ed': 2.72}, {'w': '分', 'bg': 2.72, 'ed': 2.94}, {'w': '通', 'bg': 2.94, 'ed': 3.16}, {'w': '过', 'bg': 3.16, 'ed': 3.36}, {'w': '示', 'bg': 3.36, 'ed': 3.54}, {'w': '意', 'bg': 3.54, 'ed': 3.68}, {'w': '图', 'bg': 3.68, 'ed': 3.9}, {'w': '和', 'bg': 3.9, 'ed': 4.14}, {'w': '文', 'bg': 4.14, 'ed': 4.32}, {'w': '本', 'bg': 4.32, 'ed': 4.46}, {'w': '的', 'bg': 4.46, 'ed': 4.58}, {'w': '形', 'bg': 4.58, 'ed': 4.72}, {'w': '式', 'bg': 4.72, 'ed': 5.0}, {'w': '向', 'bg': 5.0, 'ed': 5.32}, {'w': '学', 'bg': 5.32, 'ed': 5.5}, {'w': '生', 'bg': 5.5, 'ed': 5.66}, {'w': '讲', 'bg': 5.66, 'ed': 5.86}, {'w': '解', 'bg': 5.86, 'ed': 6.18}, {'w': '主', 'bg': 6.18, 'ed': 6.46}, {'w': '要', 'bg': 6.46, 'ed': 6.62}, {'w': '传', 'bg': 6.62, 'ed': 6.8}, {'w': '感', 'bg': 6.8, 'ed': 7.0}, {'w': '器', 'bg': 7.0, 'ed': 7.16}, {'w': '的', 'bg': 7.16, 'ed': 7.28}, {'w': '工', 'bg': 7.28, 'ed': 7.44}, {'w': '作', 'bg': 7.44, 'ed': 7.6000000000000005}, {'w': '原', 'bg': 7.6000000000000005, 'ed': 7.74}, {'w': '理', 'bg': 7.74, 'ed': 8.06}, {'w': '让', 'bg': 8.06, 'ed': 8.44}, {'w': '学', 'bg': 8.44, 'ed': 8.64}, {'w': '生', 'bg': 8.64, 'ed': 8.84}, {'w': '对', 'bg': 8.84, 'ed': 9.06}, {'w': '设', 'bg': 9.06, 'ed': 9.24}, {'w': '备', 'bg': 9.24, 'ed': 9.52}, {'w': '有', 'bg': 9.52, 'ed': 9.86}, {'w': '大', 'bg': 9.86, 'ed': 10.1}, {'w': '致', 'bg': 10.1, 'ed': 10.24}, {'w': '的', 'bg': 10.24, 'ed': 10.36}, {'w': '认', 'bg': 10.36, 'ed': 10.5}, {'w': '知', 'bg': 
10.5, 'ed': 11.040000000000001}, {'w': '随', 'bg': 11.040000000000001, 'ed': 11.56}, {'w': '后', 'bg': 11.56, 'ed': 11.82}, {'w': '使', 'bg': 11.82, 'ed': 12.1}, {'w': '用', 'bg': 12.1, 'ed': 12.26}, {'w': '真', 'bg': 12.26, 'ed': 12.44}, {'w': '实', 'bg': 12.44, 'ed': 12.620000000000001}, {'w': '传', 'bg': 12.620000000000001, 'ed': 12.780000000000001}, {'w': '感', 'bg': 12.780000000000001, 'ed': 12.94}, {'w': '器', 'bg': 12.94, 'ed': 13.1}, {'w': '的', 'bg': 13.1, 'ed': 13.26}, {'w': '内', 'bg': 13.26, 'ed': 13.42}, {'w': '部', 'bg': 13.42, 'ed': 13.56}, {'w': '构', 'bg': 13.56, 'ed': 13.700000000000001}, {'w': '造', 'bg': 13.700000000000001, 'ed': 13.86}, {'w': '图', 'bg': 13.86, 'ed': 14.280000000000001}, {'w': '辅', 'bg': 14.280000000000001, 'ed': 14.66}, {'w': '以', 'bg': 14.66, 'ed': 14.82}, {'w': '文', 'bg': 14.82, 'ed': 15.0}, {'w': '字', 'bg': 15.0, 'ed': 15.16}, {'w': '说', 'bg': 15.16, 'ed': 15.32}, {'w': '明', 'bg': 15.32, 'ed': 15.72}, {'w': '进', 'bg': 15.72, 'ed': 16.1}, {'w': '一', 'bg': 16.1, 'ed': 16.2}, {'w': '步', 'bg': 16.2, 'ed': 16.32}, {'w': '帮', 'bg': 16.32, 'ed': 16.48}, {'w': '助', 'bg': 16.48, 'ed': 16.66}, {'w': '学', 'bg': 16.66, 'ed': 16.82}, {'w': '生', 'bg': 16.82, 'ed': 17.12}, {'w': '对', 'bg': 17.12, 'ed': 17.48}, {'w': '传', 'bg': 17.48, 'ed': 17.66}, {'w': '感', 'bg': 17.66, 'ed': 17.84}, {'w': '器', 'bg': 17.84, 'ed': 18.12}, {'w': '有', 'bg': 18.12, 'ed': 18.42}, {'w': '更', 'bg': 18.42, 'ed': 18.66}, {'w': '深', 'bg': 18.66, 'ed': 18.88}, {'w': '刻', 'bg': 18.88, 'ed': 19.04}, {'w': '的', 'bg': 19.04, 'ed': 19.16}, {'w': '印', 'bg': 19.16, 'ed': 19.3}, {'w': '象', 'bg': 19.3, 'ed': 19.8}, {'w': '最', 'bg': 19.8, 'ed': 20.3}, {'w': '后', 'bg': 20.3, 'ed': 20.62}, {'w': '结', 'bg': 20.62, 'ed': 20.96}, {'w': '合', 'bg': 20.96, 'ed': 21.14}, {'w': '具', 'bg': 21.14, 'ed': 21.3}, {'w': '体', 'bg': 21.3, 'ed': 21.42}, {'w': '的', 'bg': 21.42, 'ed': 21.580000000000002}, {'w': '实', 'bg': 21.580000000000002, 'ed': 21.76}, {'w': '践', 'bg': 21.76, 'ed': 21.92}, {'w': '应', 'bg': 21.92, 'ed': 22.080000000000002}, {'w': '用', 'bg': 22.080000000000002, 'ed': 22.44}, {'w': '提', 'bg': 22.44, 'ed': 22.78}, {'w': '升', 'bg': 22.78, 'ed': 22.94}, {'w': '学', 'bg': 22.94, 'ed': 23.12}, {'w': '生', 'bg': 23.12, 'ed': 23.34}, {'w': '对', 'bg': 23.34, 'ed': 23.62}, {'w': '实', 'bg': 23.62, 'ed': 23.82}, {'w': '训', 'bg': 23.82, 'ed': 23.96}, {'w': '的', 'bg': 23.96, 'ed': 24.12}, {'w': '兴', 'bg': 24.12, 'ed': 24.3}, {'w': '趣', 'bg': 24.3, 'ed': 24.6}, {'w': '以', 'bg': 24.6, 'ed': 24.88}, {'w': '及', 'bg': 24.88, 'ed': 25.12}, {'w': '意', 'bg': 25.12, 'ed': 25.34}, {'w': '义', 'bg': 25.34, 'ed': 25.46}, {'w': '感', 'bg': 25.46, 'ed': 26.04}]} +[2023-03-30 23:27:01,060] [ INFO] - audio duration: 26.04, elapsed time: 46.581613540649414, RTF=1.7888484462614982 +sentences: ['第一部分是认知部分', '该部分通过示意图和文本的形式向学生讲解主要传感器的工作原理', '让学生对设备有大致的认知', '随后使用真实传感器的内部构造图', '辅以文字说明', '进一步帮助学生对传感器有更深刻的印象', '最后结合具体的实践应用', '提升学生对实训的兴趣以及意义感'] +relative_times: [[0.0, 2.1], [2.1, 8.06], [8.06, 11.040000000000001], [11.040000000000001, 14.280000000000001], [14.280000000000001, 15.72], [15.72, 19.8], [19.8, 22.44], [22.44, 26.04]] +[2023-03-30 23:27:01,076] [ INFO] - results saved to /home/fxb/PaddleSpeech-develop/data/认知.srt + ``` diff --git a/demos/streaming_asr_server/local/websocket_client_srt.py b/demos/streaming_asr_server/local/websocket_client_srt.py new file mode 100644 index 00000000..02fea484 --- /dev/null +++ b/demos/streaming_asr_server/local/websocket_client_srt.py @@ -0,0 +1,162 @@ +#!/usr/bin/python +# Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# calc avg RTF(NOT Accurate): grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}' +# python3 websocket_client_srt.py --server_ip 127.0.0.1 --port 8290 --punc.server_ip 127.0.0.1 --punc.port 8190 --wavfile ./zh.wav +# python3 websocket_client_srt.py --server_ip 127.0.0.1 --port 8290 --wavfile ./zh.wav +import argparse +import asyncio +import codecs +import os +import re + +from pydub import AudioSegment + +from paddlespeech.cli.log import logger +from paddlespeech.server.utils.audio_handler import ASRWsAudioHandler + +def convert_to_wav(input_file): + # Load audio file + audio = AudioSegment.from_file(input_file) + + # Set parameters for audio file + audio = audio.set_channels(1) + audio = audio.set_frame_rate(16000) + + # Create output filename + output_file = os.path.splitext(input_file)[0] + ".wav" + + # Export audio file as WAV + audio.export(output_file, format="wav") + + logger.info(f"{input_file} converted to {output_file}") + +def format_time(sec): + # Convert seconds to SRT format (HH:MM:SS,ms) + hours = int(sec/3600) + minutes = int((sec%3600)/60) + seconds = int(sec%60) + milliseconds = int((sec%1)*1000) + return f'{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}' + +def results2srt(results, srt_file): + """convert results from paddlespeech to srt format for subtitle + Args: + results (dict): results from paddlespeech + """ + # times contains start and end time of each word + times = results['times'] + # result contains the whole sentence including punctuation + result = results['result'] + # split result into several sentences by ',' and '。' + sentences = re.split(',|。', result)[:-1] + # print("sentences: ", sentences) + # generate relative time for each sentence in sentences + relative_times = [] + word_i = 0 + for sentence in sentences: + relative_times.append([]) + for word in sentence: + if relative_times[-1] == []: + relative_times[-1].append(times[word_i]['bg']) + if len(relative_times[-1]) == 1: + relative_times[-1].append(times[word_i]['ed']) + else: + relative_times[-1][1] = times[word_i]['ed'] + word_i += 1 + # print("relative_times: ", relative_times) + # generate srt file according to relative_times and sentences + with open(srt_file, 'w') as f: + for i in range(len(sentences)): + # Write index number + f.write(str(i+1)+'\n') + + # Write start and end times + start = format_time(relative_times[i][0]) + end = format_time(relative_times[i][1]) + f.write(start + ' --> ' + end + '\n') + + # Write text + f.write(sentences[i]+'\n\n') + logger.info(f"results saved to {srt_file}") + +def main(args): + logger.info("asr websocket client start") + handler = ASRWsAudioHandler( + args.server_ip, + args.port, + endpoint=args.endpoint, + punc_server_ip=args.punc_server_ip, + punc_server_port=args.punc_server_port) + loop = asyncio.get_event_loop() + + # check if the wav file is mp3 format + # if so, convert it to wav format using
convert_to_wav function + if args.wavfile and os.path.exists(args.wavfile): + if args.wavfile.endswith(".mp3"): + convert_to_wav(args.wavfile) + args.wavfile = args.wavfile.replace(".mp3", ".wav") + + # support processing a single audio file + if args.wavfile and os.path.exists(args.wavfile): + logger.info(f"start to process the wav file: {args.wavfile}") + result = loop.run_until_complete(handler.run(args.wavfile)) + # result = result["result"] + # logger.info(f"asr websocket client finished : {result}") + results2srt(result, args.wavfile.replace(".wav", ".srt")) + + # support processing batch audios from wav.scp + if args.wavscp and os.path.exists(args.wavscp): + logger.info(f"start to process the wavscp: {args.wavscp}") + with codecs.open(args.wavscp, 'r', encoding='utf-8') as f,\ + codecs.open("result.txt", 'w', encoding='utf-8') as w: + for line in f: + utt_name, utt_path = line.strip().split() + result = loop.run_until_complete(handler.run(utt_path)) + result = result["result"] + w.write(f"{utt_name} {result}\n") + + +if __name__ == "__main__": + logger.info("Start to do streaming asr client") + parser = argparse.ArgumentParser() + parser.add_argument( + '--server_ip', type=str, default='127.0.0.1', help='server ip') + parser.add_argument('--port', type=int, default=8090, help='server port') + parser.add_argument( + '--punc.server_ip', + type=str, + default=None, + dest="punc_server_ip", + help='Punctuation server ip') + parser.add_argument( + '--punc.port', + type=int, + default=8091, + dest="punc_server_port", + help='Punctuation server port') + parser.add_argument( + "--endpoint", + type=str, + default="/paddlespeech/asr/streaming", + help="ASR websocket endpoint") + parser.add_argument( + "--wavfile", + action="store", + help="wav file path", + default="./16_audio.wav") + parser.add_argument( + "--wavscp", type=str, default=None, help="The batch audios dict text") + args = parser.parse_args() + + main(args) diff --git a/docs/source/tts/quick_start.md b/docs/source/tts/quick_start.md index d8dbc646..d2a1b4ec 100644 --- a/docs/source/tts/quick_start.md +++ b/docs/source/tts/quick_start.md @@ -79,8 +79,8 @@ checkpoint_name ├── snapshot_iter_*.pdz ├── speech_stats.npy ├── phone_id_map.txt -├── spk_id_map.txt (optimal) -└── tone_id_map.txt (optimal) +├── spk_id_map.txt (optional) +└── tone_id_map.txt (optional) ``` **Vocoders:** ```text diff --git a/docs/source/tts/quick_start_cn.md b/docs/source/tts/quick_start_cn.md index c56d9bb4..ba259643 100644 --- a/docs/source/tts/quick_start_cn.md +++ b/docs/source/tts/quick_start_cn.md @@ -87,8 +87,8 @@ checkpoint_name ├── snapshot_iter_*.pdz ├── speech_stats.npy ├── phone_id_map.txt -├── spk_id_map.txt (optimal) -└── tone_id_map.txt (optimal) +├── spk_id_map.txt (optional) +└── tone_id_map.txt (optional) ``` **Vocoders:** ```text diff --git a/docs/tutorial/st/st_tutorial.ipynb b/docs/tutorial/st/st_tutorial.ipynb index 2fb85053..e755beba 100644 --- a/docs/tutorial/st/st_tutorial.ipynb +++ b/docs/tutorial/st/st_tutorial.ipynb @@ -62,7 +62,7 @@ "collapsed": false }, "source": [ - "# 使用Transformer进行端到端语音翻译的的基本流程\n", + "# 使用Transformer进行端到端语音翻译的基本流程\n", "## 基础模型\n", "由于 ASR 章节已经介绍了 Transformer 以及语音特征抽取,在此便不做过多介绍,感兴趣的同学可以去相关章节进行了解。\n", "\n", diff --git a/docs/tutorial/tts/tts_tutorial.ipynb b/docs/tutorial/tts/tts_tutorial.ipynb index 583adb01..0cecb680 100644 --- a/docs/tutorial/tts/tts_tutorial.ipynb +++ b/docs/tutorial/tts/tts_tutorial.ipynb @@ -464,7 +464,7 @@ "
FastSpeech2 网络结构图

\n", "\n", "\n", - "PaddleSpeech TTS 实现的 FastSpeech2 与论文不同的地方在于,我们使用的的是 phone 级别的 `pitch` 和 `energy`(与 FastPitch 类似),这样的合成结果可以更加**稳定**。\n", + "PaddleSpeech TTS 实现的 FastSpeech2 与论文不同的地方在于,我们使用的是 phone 级别的 `pitch` 和 `energy`(与 FastPitch 类似),这样的合成结果可以更加**稳定**。\n", "
\n", "
FastPitch 网络结构图

\n", "\n", diff --git a/examples/aishell/asr0/local/train.sh b/examples/aishell/asr0/local/train.sh index 2b71b7f7..c0da3325 100755 --- a/examples/aishell/asr0/local/train.sh +++ b/examples/aishell/asr0/local/train.sh @@ -1,6 +1,6 @@ #!/bin/bash -if [ $# -lt 2 ] && [ $# -gt 3 ];then +if [ $# -lt 2 ] || [ $# -gt 3 ];then echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi diff --git a/examples/aishell/asr1/local/test.sh b/examples/aishell/asr1/local/test.sh index 26926b4a..8487e990 100755 --- a/examples/aishell/asr1/local/test.sh +++ b/examples/aishell/asr1/local/test.sh @@ -1,15 +1,21 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" - exit -1 -fi +set -e stage=0 stop_stage=100 + +source utils/parse_options.sh || exit 1; + ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." + +if [ $# != 3 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" + exit -1 +fi + config_path=$1 decode_config_path=$2 ckpt_prefix=$3 @@ -92,6 +98,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then fi if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then + echo "using sclite to compute cer..." # format the reference test file for sclite python utils/format_rsl.py \ --origin_ref data/manifest.test.raw \ diff --git a/examples/aishell/asr1/local/train.sh b/examples/aishell/asr1/local/train.sh index bfa8dd97..3d4f052a 100755 --- a/examples/aishell/asr1/local/train.sh +++ b/examples/aishell/asr1/local/train.sh @@ -17,7 +17,7 @@ if [ ${seed} != 0 ]; then echo "using seed $seed & FLAGS_cudnn_deterministic=True ..." fi -if [ $# -lt 2 ] && [ $# -gt 3 ];then +if [ $# -lt 2 ] || [ $# -gt 3 ];then echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi diff --git a/examples/aishell/asr3/local/train.sh b/examples/aishell/asr3/local/train.sh index e51e3d34..33fef0fd 100755 --- a/examples/aishell/asr3/local/train.sh +++ b/examples/aishell/asr3/local/train.sh @@ -1,6 +1,6 @@ #!/bin/bash -if [ $# -lt 2 ] && [ $# -gt 3 ];then +if [ $# -lt 2 ] || [ $# -gt 3 ];then echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)" exit -1 fi diff --git a/examples/csmsc/jets/README.md b/examples/csmsc/jets/README.md new file mode 100644 index 00000000..07dade0e --- /dev/null +++ b/examples/csmsc/jets/README.md @@ -0,0 +1,108 @@ +# JETS with CSMSC +This example contains code used to train a [JETS](https://arxiv.org/abs/2203.16852v1) model with [Chinese Standard Mandarin Speech Copus](https://www.data-baker.com/open_source.html). + +## Dataset +### Download and Extract +Download CSMSC from it's [Official Website](https://test.data-baker.com/data/index/source). + +### Get MFA Result and Extract +We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes and durations for JETS. +You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) of our repo. + +## Get Started +Assume the path to the dataset is `~/datasets/BZNSYP`. +Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`. +Run the command below to +1. **source path**. +2. preprocess the dataset. +3. train the model. +4. synthesize wavs. + - synthesize waveform from `metadata.jsonl`. 
+ - synthesize waveform from a text file. + +```bash +./run.sh +``` +You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset. +```bash +./run.sh --stage 0 --stop-stage 0 +``` +### Data Preprocessing +```bash +./local/preprocess.sh ${conf_path} +``` +When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below. + +```text +dump +├── dev +│   ├── norm +│   └── raw +├── phone_id_map.txt +├── speaker_id_map.txt +├── test +│   ├── norm +│   └── raw +└── train + ├── feats_stats.npy + ├── norm + └── raw +``` +The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the wave, mel spectrogram, speech, pitch, and energy features of each utterance, while the `norm` folder contains the normalized ones. The statistics used to normalize the features are computed from the training set and stored in `dump/train/feats_stats.npy`. + +Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, the path of feats, feats_lengths, the path of pitch features, the path of energy features, the path of raw waves, speaker, and the id of each utterance. + +### Model Training +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} +``` +`./local/train.sh` calls `${BIN_DIR}/train.py`. +Here's the complete help message. +```text +usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] + [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] + [--ngpu NGPU] [--phones-dict PHONES_DICT] + +Train a JETS model. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG config file to overwrite default config. + --train-metadata TRAIN_METADATA + training data. + --dev-metadata DEV_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. + --phones-dict PHONES_DICT + phone vocabulary file. +``` +1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. +2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder. +3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory. +4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. +5. `--phones-dict` is the path of the phone vocabulary file. + +### Synthesizing + +`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveforms from `metadata.jsonl`. + +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` + +`./local/synthesize_e2e.sh` calls `${BIN_DIR}/synthesize_e2e.py`, which can synthesize waveforms from a text file.
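+The `--text` input is a plain text file with one utterance per line. The sketch below shows the assumed `utt_id sentence` layout (matching `${BIN_DIR}/../sentences.txt`); `my_sentences.txt` is a hypothetical file name used only for illustration:
+
+```python
+# Write a tiny input file for synthesize_e2e.sh,
+# one "utt_id sentence" pair per line (assumed layout).
+lines = [
+    "001 欢迎使用飞桨语音合成系统。",
+    "002 今天天气真不错。",
+]
+with open("my_sentences.txt", "w", encoding="utf-8") as f:
+    f.write("\n".join(lines) + "\n")
+```
+
+With the text file prepared, run end-to-end synthesis as follows: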
+```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` + +## Pretrained Model + +The pretrained model can be downloaded here: + +- [jets_csmsc_ckpt_1.5.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/jets_csmsc_ckpt_1.5.0.zip) + +The static model can be downloaded here: + +- [jets_csmsc_static_1.5.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/jets_csmsc_static_1.5.0.zip) diff --git a/examples/csmsc/jets/conf/default.yaml b/examples/csmsc/jets/conf/default.yaml new file mode 100644 index 00000000..1dafd20c --- /dev/null +++ b/examples/csmsc/jets/conf/default.yaml @@ -0,0 +1,224 @@ +# This configuration was tested on 4 GPUs (V100) with 32GB GPU +# memory. It takes around 2 weeks to finish the training, +# but a 100k-iteration model should already generate reasonable results. +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### + +n_mels: 80 +fs: 22050 # sr +n_fft: 1024 # FFT size (samples). +n_shift: 256 # Hop size (samples), ~11.6 ms at fs=22050. +win_length: null # Window length (samples). + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +fmin: 0 # minimum frequency for Mel basis +fmax: null # maximum frequency for Mel basis +f0min: 80 # Minimum f0 for pitch extraction. +f0max: 400 # Maximum f0 for pitch extraction. + + +########################################################## +# TTS MODEL SETTING # +########################################################## +model: + # generator related + generator_type: jets_generator + generator_params: + adim: 256 # attention dimension + aheads: 2 # number of attention heads + elayers: 4 # number of encoder layers + eunits: 1024 # number of encoder ff units + dlayers: 4 # number of decoder layers + dunits: 1024 # number of decoder ff units + positionwise_layer_type: conv1d # type of position-wise layer + positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer + duration_predictor_layers: 2 # number of layers of duration predictor + duration_predictor_chans: 256 # number of channels of duration predictor + duration_predictor_kernel_size: 3 # filter size of duration predictor + use_masking: True # whether to apply masking for padded part in loss calculation + encoder_normalize_before: True # whether to perform layer normalization before the input + decoder_normalize_before: True # whether to perform layer normalization before the input + encoder_type: transformer # encoder type + decoder_type: transformer # decoder type + conformer_rel_pos_type: latest # relative positional encoding type + conformer_pos_enc_layer_type: rel_pos # conformer positional encoding type + conformer_self_attn_layer_type: rel_selfattn # conformer self-attention type + conformer_activation_type: swish # conformer activation type + use_macaron_style_in_conformer: true # whether to use macaron style in conformer + use_cnn_in_conformer: true # whether to use CNN in conformer + conformer_enc_kernel_size: 7 # kernel size in CNN module of conformer-based encoder + conformer_dec_kernel_size: 31 # kernel size in CNN module of conformer-based decoder + init_type: xavier_uniform # initialization type + init_enc_alpha: 1.0 # initial value of alpha for encoder + init_dec_alpha: 1.0 # initial value of alpha for decoder + transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer + transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional
encoding + transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer + transformer_dec_dropout_rate: 0.2 # dropout rate for transformer decoder layer + transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding + transformer_dec_attn_dropout_rate: 0.2 # dropout rate for transformer decoder attention layer + pitch_predictor_layers: 5 # number of conv layers in pitch predictor + pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor + pitch_predictor_kernel_size: 5 # kernel size of conv layers in pitch predictor + pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor + pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch + pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch + stop_gradient_from_pitch_predictor: true # whether to stop the gradient from pitch predictor to encoder + energy_predictor_layers: 2 # number of conv layers in energy predictor + energy_predictor_chans: 256 # number of channels of conv layers in energy predictor + energy_predictor_kernel_size: 3 # kernel size of conv layers in energy predictor + energy_predictor_dropout: 0.5 # dropout rate in energy predictor + energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy + energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy + stop_gradient_from_energy_predictor: false # whether to stop the gradient from energy predictor to encoder + generator_out_channels: 1 + generator_channels: 512 + generator_global_channels: -1 + generator_kernel_size: 7 + generator_upsample_scales: [8, 8, 2, 2] + generator_upsample_kernel_sizes: [16, 16, 4, 4] + generator_resblock_kernel_sizes: [3, 7, 11] + generator_resblock_dilations: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + generator_use_additional_convs: true + generator_bias: true + generator_nonlinear_activation: "leakyrelu" + generator_nonlinear_activation_params: + negative_slope: 0.1 + generator_use_weight_norm: true + segment_size: 64 # segment size for random windowed discriminator + + # discriminator related + discriminator_type: hifigan_multi_scale_multi_period_discriminator + discriminator_params: + scales: 1 + scale_downsample_pooling: "AvgPool1D" + scale_downsample_pooling_params: + kernel_size: 4 + stride: 2 + padding: 2 + scale_discriminator_params: + in_channels: 1 + out_channels: 1 + kernel_sizes: [15, 41, 5, 3] + channels: 128 + max_downsample_channels: 1024 + max_groups: 16 + bias: True + downsample_scales: [2, 2, 4, 4, 1] + nonlinear_activation: "leakyrelu" + nonlinear_activation_params: + negative_slope: 0.1 + use_weight_norm: True + use_spectral_norm: False + follow_official_norm: False + periods: [2, 3, 5, 7, 11] + period_discriminator_params: + in_channels: 1 + out_channels: 1 + kernel_sizes: [5, 3] + channels: 32 + downsample_scales: [3, 3, 3, 3, 1] + max_downsample_channels: 1024 + bias: True + nonlinear_activation: "leakyrelu" + nonlinear_activation_params: + negative_slope: 0.1 + use_weight_norm: True + use_spectral_norm: False + # others + sampling_rate: 22050 # needed in the inference for saving wav + cache_generator_outputs: True # whether to cache generator outputs in the training +use_alignment_module: False # whether to use alignment module + +########################################################### +# LOSS SETTING # +########################################################### +# loss function related +generator_adv_loss_params: + average_by_discriminators:
False # whether to average loss value by #discriminators + loss_type: mse # loss type, "mse" or "hinge" +discriminator_adv_loss_params: + average_by_discriminators: False # whether to average loss value by #discriminators + loss_type: mse # loss type, "mse" or "hinge" +feat_match_loss_params: + average_by_discriminators: False # whether to average loss value by #discriminators + average_by_layers: False # whether to average loss value by #layers of each discriminator + include_final_outputs: True # whether to include final outputs for loss calculation +mel_loss_params: + fs: 22050 # must be the same as the training data + fft_size: 1024 # fft points + hop_size: 256 # hop size + win_length: null # window length + window: hann # window type + num_mels: 80 # number of Mel basis + fmin: 0 # minimum frequency for Mel basis + fmax: null # maximum frequency for Mel basis + log_base: null # null represents natural log + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_adv: 1.0 # loss scaling coefficient for adversarial loss +lambda_mel: 45.0 # loss scaling coefficient for Mel loss +lambda_feat_match: 2.0 # loss scaling coefficient for feat match loss +lambda_var: 1.0 # loss scaling coefficient for duration loss +lambda_align: 2.0 # loss scaling coefficient for alignment loss +# others +sampling_rate: 22050 # needed in the inference for saving wav +cache_generator_outputs: True # whether to cache generator outputs in the training + + +# extra module for additional inputs +pitch_extract: dio # pitch extractor type +pitch_extract_conf: + reduction_factor: 1 + use_token_averaged_f0: false +pitch_normalize: global_mvn # normalizer for the pitch feature +energy_extract: energy # energy extractor type +energy_extract_conf: + reduction_factor: 1 + use_token_averaged_energy: false +energy_normalize: global_mvn # normalizer for the energy feature + + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 32 # Batch size. +num_workers: 4 # Number of workers in DataLoader. + +########################################################## +# OPTIMIZER & SCHEDULER SETTING # +########################################################## +# optimizer setting for generator +generator_optimizer_params: + beta1: 0.8 + beta2: 0.99 + epsilon: 1.0e-9 + weight_decay: 0.0 +generator_scheduler: exponential_decay +generator_scheduler_params: + learning_rate: 2.0e-4 + gamma: 0.999875 + +# optimizer setting for discriminator +discriminator_optimizer_params: + beta1: 0.8 + beta2: 0.99 + epsilon: 1.0e-9 + weight_decay: 0.0 +discriminator_scheduler: exponential_decay +discriminator_scheduler_params: + learning_rate: 2.0e-4 + gamma: 0.999875 +generator_first: True # whether to start updating generator first + +########################################################## +# OTHER TRAINING SETTING # +########################################################## +num_snapshots: 10 # max number of snapshots to keep while training +train_max_steps: 350000 # Number of training steps. == total_iters / ngpus, total_iters = 1000000 +save_interval_steps: 1000 # Interval steps to save checkpoint. +eval_interval_steps: 250 # Interval steps to evaluate the network.
+seed: 777 # random seed number diff --git a/examples/csmsc/jets/local/inference.sh b/examples/csmsc/jets/local/inference.sh new file mode 100755 index 00000000..30941caa --- /dev/null +++ b/examples/csmsc/jets/local/inference.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +train_output_path=$1 + +stage=0 +stop_stage=0 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/inference.py \ + --inference_dir=${train_output_path}/inference \ + --am=jets_csmsc \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/pd_infer_out \ + --phones_dict=dump/phone_id_map.txt +fi diff --git a/examples/csmsc/jets/local/preprocess.sh b/examples/csmsc/jets/local/preprocess.sh new file mode 100755 index 00000000..60053131 --- /dev/null +++ b/examples/csmsc/jets/local/preprocess.sh @@ -0,0 +1,77 @@ +#!/bin/bash +set -e +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./baker_alignment_tone \ + --output=durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/preprocess.py \ + --dataset=baker \ + --rootdir=~/datasets/BZNSYP/ \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --num-cpu=20 \ + --cut-sil=True \ + --token_average=True +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats (mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="feats" + + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="pitch" + + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="energy" + +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize and convert phone/speaker to id, dev and test should use train's stats + echo "Normalize ..."
+ python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --feats-stats=dump/train/feats_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --feats-stats=dump/train/feats_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --feats-stats=dump/train/feats_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt +fi diff --git a/examples/csmsc/jets/local/synthesize.sh b/examples/csmsc/jets/local/synthesize.sh new file mode 100755 index 00000000..a4b35ec0 --- /dev/null +++ b/examples/csmsc/jets/local/synthesize.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +stage=0 +stop_stage=0 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/synthesize.py \ + --config=${config_path} \ + --ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --phones_dict=dump/phone_id_map.txt \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test +fi diff --git a/examples/csmsc/jets/local/synthesize_e2e.sh b/examples/csmsc/jets/local/synthesize_e2e.sh new file mode 100755 index 00000000..67ae14fa --- /dev/null +++ b/examples/csmsc/jets/local/synthesize_e2e.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +stage=0 +stop_stage=0 + + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/synthesize_e2e.py \ + --am=jets_csmsc \ + --config=${config_path} \ + --ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --phones_dict=dump/phone_id_map.txt \ + --output_dir=${train_output_path}/test_e2e \ + --text=${BIN_DIR}/../sentences.txt \ + --inference_dir=${train_output_path}/inference +fi diff --git a/examples/csmsc/jets/local/train.sh b/examples/csmsc/jets/local/train.sh new file mode 100755 index 00000000..d1302f99 --- /dev/null +++ b/examples/csmsc/jets/local/train.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +python3 ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=1 \ + --phones-dict=dump/phone_id_map.txt diff --git a/examples/csmsc/jets/path.sh b/examples/csmsc/jets/path.sh new file mode 100755 index 00000000..73a0af7e --- /dev/null +++ b/examples/csmsc/jets/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=jets +export 
BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL} diff --git a/examples/csmsc/jets/run.sh b/examples/csmsc/jets/run.sh new file mode 100755 index 00000000..d0985c50 --- /dev/null +++ b/examples/csmsc/jets/run.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_150000.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # synthesize_e2e + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1 +fi + diff --git a/examples/csmsc/vits/local/lite_predict.sh b/examples/csmsc/vits/local/lite_predict.sh index 9ed57b72..e12f5349 100755 --- a/examples/csmsc/vits/local/lite_predict.sh +++ b/examples/csmsc/vits/local/lite_predict.sh @@ -7,7 +7,7 @@ stage=0 stop_stage=0 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - python3 ${BIN_DIR}/../lite_predict.py \ + python3 ${BIN_DIR}/lite_predict.py \ --inference_dir=${train_output_path}/pdlite \ --am=vits_csmsc \ --text=${BIN_DIR}/../sentences.txt \ diff --git a/examples/csmsc/vits/run.sh b/examples/csmsc/vits/run.sh index 03c59702..f6e8a086 100755 --- a/examples/csmsc/vits/run.sh +++ b/examples/csmsc/vits/run.sh @@ -54,16 +54,16 @@ fi # ./local/ort_predict.sh ${train_output_path} # fi -# # not ready yet for operator missing in Paddle-Lite -# # must run after stage 3 (which stage generated static models) -# if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then -# # NOTE by yuantian 2022.11.21: please compile develop version of Paddle-Lite to export and run TTS models, -# # cause TTS models are supported by https://github.com/PaddlePaddle/Paddle-Lite/pull/9587 -# # and https://github.com/PaddlePaddle/Paddle-Lite/pull/9706 -# ./local/export2lite.sh ${train_output_path} inference pdlite vits_csmsc x86 -# fi +# not ready yet due to missing operators in Paddle-Lite +# must run after stage 3 (the stage that generates the static models) +if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then + # NOTE by yuantian 2022.11.21: please compile the develop version of Paddle-Lite to export and run TTS models, + # because TTS models are supported by https://github.com/PaddlePaddle/Paddle-Lite/pull/10128 + # vits can only run on ARM + ./local/export2lite.sh ${train_output_path} inference pdlite vits_csmsc arm +fi -# if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then -# CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1 -# fi +if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
+fi diff --git a/examples/csmsc/voc5/conf/iSTFT.yaml b/examples/csmsc/voc5/conf/iSTFT.yaml new file mode 100644 index 00000000..06677d79 --- /dev/null +++ b/examples/csmsc/voc5/conf/iSTFT.yaml @@ -0,0 +1,174 @@ +# This is the configuration file for the CSMSC dataset. +# This configuration is based on HiFiGAN V1, which is an official configuration. +# But I found that the optimizer setting does not work well with my implementation. +# So I changed optimizer settings as follows: +# - AdamW -> Adam +# - betas: [0.8, 0.99] -> betas: [0.5, 0.9] +# - Scheduler: ExponentialLR -> MultiStepLR +# To match the shift size difference, the upsample scales are also modified from the original 256 shift setting. + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fs: 24000 # Sampling rate. +n_fft: 2048 # FFT size (samples). +n_shift: 300 # Hop size (samples). 12.5ms +win_length: 1200 # Window length (samples). 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. (Hz) +fmax: 7600 # Maximum frequency in mel basis calculation. (Hz) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + use_istft: True # Use iSTFTNet. + istft_layer_id: 2 # Use iSTFT after istft_layer_id upsample layers if use_istft=True. + n_fft: 2048 # FFT size (samples) in feature extraction. + win_length: 1200 # Window length (samples) in feature extraction. + in_channels: 80 # Number of input channels. + out_channels: 1 # Number of output channels. + channels: 512 # Number of initial channels. + kernel_size: 7 # Kernel size of initial and final conv layers. + upsample_scales: [5, 5, 4, 3] # Upsampling scales. + upsample_kernel_sizes: [10, 10, 8, 6] # Kernel size for upsampling layers. + resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks. + resblock_dilations: # Dilations for residual blocks. + - [1, 3, 5] + - [1, 3, 5] + - [1, 3, 5] + use_additional_convs: True # Whether to use additional conv layer in residual blocks. + bias: True # Whether to use bias parameter in conv. + nonlinear_activation: "leakyrelu" # Nonlinear activation type. + nonlinear_activation_params: # Nonlinear activation parameters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + + + + + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + scales: 3 # Number of multi-scale discriminator. + scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator. + scale_downsample_pooling_params: + kernel_size: 4 # Pooling kernel size. + stride: 2 # Pooling stride. + padding: 2 # Padding size. + scale_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. + channels: 128 # Initial number of channels. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + max_groups: 16 # Maximum number of groups in downsampling conv layers. + bias: True + downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. + nonlinear_activation: "leakyrelu" # Nonlinear activation.
+ nonlinear_activation_params: + negative_slope: 0.1 + follow_official_norm: True # Whether to follow the official norm setting. + periods: [2, 3, 5, 7, 11] # List of periods for the multi-period discriminator. + period_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [5, 3] # List of kernel sizes. + channels: 32 # Initial number of channels. + downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + bias: True # Whether to use bias parameter in conv layer. + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: # Nonlinear activation parameters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + use_spectral_norm: False # Whether to apply spectral normalization. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: False # Whether to use multi-resolution STFT loss. +use_mel_loss: True # Whether to use Mel-spectrogram loss. +mel_loss_params: + fs: 24000 + fft_size: 2048 + hop_size: 300 + win_length: 1200 + window: "hann" + num_mels: 80 + fmin: 0 + fmax: 12000 + log_base: null +generator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +discriminator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +use_feat_match_loss: True +feat_match_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. + average_by_layers: False # Whether to average loss by #layers in each discriminator. + include_final_outputs: False # Whether to include final outputs in feat match loss calculation. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_aux: 45.0 # Loss balancing coefficient for STFT loss. +lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss. +lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss. + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 16 # Batch size. +batch_max_steps: 8400 # Length of each audio in batch. Make sure divisible by hop_size. +num_workers: 2 # Number of workers in DataLoader. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Generator's weight decay coefficient. +generator_scheduler_params: + learning_rate: 2.0e-4 # Generator's learning rate. + gamma: 0.5 # Generator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +generator_grad_norm: -1 # Generator's gradient norm. +discriminator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Discriminator's weight decay coefficient. +discriminator_scheduler_params: + learning_rate: 2.0e-4 # Discriminator's learning rate. + gamma: 0.5 # Discriminator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma.
+ - 200000 + - 400000 + - 600000 + - 800000 +discriminator_grad_norm: -1 # Discriminator's gradient norm. + +########################################################### +# INTERVAL SETTING # +########################################################### +generator_train_start_steps: 1 # Number of steps to start to train generator. +discriminator_train_start_steps: 0 # Number of steps to start to train discriminator. +train_max_steps: 2500000 # Number of training steps. +save_interval_steps: 5000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random diff --git a/examples/csmsc/voc5/iSTFTNet.md b/examples/csmsc/voc5/iSTFTNet.md new file mode 100644 index 00000000..8f121938 --- /dev/null +++ b/examples/csmsc/voc5/iSTFTNet.md @@ -0,0 +1,145 @@ +# iSTFTNet with CSMSC + +This example contains code used to train an [iSTFTNet](https://arxiv.org/abs/2203.02395) model with the [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html). + +## Dataset +### Download and Extract +Download CSMSC from its [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. + +### Get MFA Result and Extract +We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edges of the audio. +You can download it from here: [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo. + +## Get Started +Assume the path to the dataset is `~/datasets/BZNSYP`. +Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`. +Run the command below to +1. **source path**. +2. preprocess the dataset. +3. train the model. +4. synthesize wavs. + - synthesize waveform from `metadata.jsonl`. +```bash +./run.sh +``` +You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset. +```bash +./run.sh --stage 0 --stop-stage 0 +``` +### Data Preprocessing +```bash +./local/preprocess.sh ${conf_path} +``` +When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below. + +```text +dump +├── dev +│ ├── norm +│ └── raw +├── test +│ ├── norm +│ └── raw +└── train + ├── norm + ├── raw + └── feats_stats.npy +``` +The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the `norm` folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set and stored in `dump/train/feats_stats.npy`. + +Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains the id and paths to the spectrogram of each utterance.
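+Each line of `metadata.jsonl` is a standalone JSON object, so the processed data can be sanity-checked with a few lines of plain Python. The snippet below is only an illustrative sketch; it assumes the preprocessing stage above has finished, and it makes no assumption about the exact field names:
+
+```python
+import json
+
+# Peek at the first utterance record of the processed metadata.
+# Field names differ between models, so just list the keys present.
+with open("dump/train/norm/metadata.jsonl", encoding="utf-8") as f:
+    first = json.loads(next(f))  # one JSON record per line
+print(sorted(first.keys()))
+```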
+ +### Model Training +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} +``` +`./local/train.sh` calls `${BIN_DIR}/train.py`. +Here's the complete help message. + +```text +usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] + [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] + [--ngpu NGPU] + +Train a HiFiGAN model. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG HiFiGAN config file. + --train-metadata TRAIN_METADATA + training data. + --dev-metadata DEV_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. +``` + +1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/iSTFT.yaml`. +2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder. +3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory. +4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. + +### Synthesizing +`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveforms from `metadata.jsonl`. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +```text +usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG] + [--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA] + [--output-dir OUTPUT_DIR] [--ngpu NGPU] + +Synthesize with GANVocoder. + +optional arguments: + -h, --help show this help message and exit + --generator-type GENERATOR_TYPE + type of GANVocoder, should in {pwgan, mb_melgan, + style_melgan, } now + --config CONFIG GANVocoder config file. + --checkpoint CHECKPOINT + snapshot to load. + --test-metadata TEST_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. +``` + +1. `--config` config file. You should use the same config with which the model is trained. +2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory. +3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory. +4. `--output-dir` is the directory to save the synthesized audio files. +5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. + +## Pretrained Models + +The pretrained model can be downloaded here: + +- [iSTFTNet_csmsc_ckpt.zip](https://pan.baidu.com/s/1SNDlRWOGOcbbrKf5w-TJaA?pwd=r1e5) + +The iSTFTNet checkpoint contains the files listed below.
+ +```text +iSTFTNet_csmsc_ckpt +├── iSTFT.yaml                  # config used to train iSTFTNet +├── feats_stats.npy               # statistics used to normalize spectrogram when training iSTFTNet +└── snapshot_iter_50000.pdz     # generator parameters of iSTFTNet +``` + +A comparison between iSTFTNet and HiFiGAN: +| Model | Step | eval/generator_loss | eval/mel_loss | eval/feature_matching_loss | RTF | +|:--------:|:--------------:|:-------------------:|:-------------:|:--------------------------:| :---: | +| HiFiGAN | 1(gpu) x 50000 | 13.989 | 0.14683 | 1.3484 | 0.01767 | +| iSTFTNet | 1(gpu) x 50000 | 13.319 | 0.14818 | 1.1069 | 0.01069 | + +> RTF is measured on the CSMSC test dataset; the test environment is AI Studio V100 16G 1 GPU, and the test command is `./run.sh --stage 2 --stop-stage 2` + +The pretrained HiFiGAN model in the comparison can be downloaded here: + +- [hifigan_csmsc_ckpt.zip](https://pan.baidu.com/s/1pGY6RYV7yEB_5hRI_JoWig?pwd=tcaj) + +## Acknowledgement + +We adapted some code from https://github.com/rishikksh20/iSTFTNet-pytorch.git. diff --git a/examples/librispeech/asr2/README.md b/examples/librispeech/asr2/README.md index 26978520..253c9b45 100644 --- a/examples/librispeech/asr2/README.md +++ b/examples/librispeech/asr2/README.md @@ -153,7 +153,7 @@ After training the model, we need to get the final model for testing and inferen ```bash if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - avg.sh lastest exp/${ckpt}/checkpoints ${avg_num} + avg.sh latest exp/${ckpt}/checkpoints ${avg_num} fi ``` The `avg.sh` is in the `../../../utils/` which is define in the `path.sh`. diff --git a/examples/other/mfa/local/generate_lexicon.py b/examples/other/mfa/local/generate_lexicon.py index 3deb2470..e63b5eb2 100644 --- a/examples/other/mfa/local/generate_lexicon.py +++ b/examples/other/mfa/local/generate_lexicon.py @@ -48,7 +48,7 @@ def rule(C, V, R, T): 'i' is distinguished when appeared in phonemes, and separated into 3 categories, 'i', 'ii' and 'iii'. - Erhua is is possibly applied to every finals, except for finals that already ends with 'r'. + Erhua is possibly applied to every final, except for finals that already end with 'r'. When a syllable is impossible or does not have any characters with this pronunciation, return None to filter it out. diff --git a/examples/tiny/asr1/README.md b/examples/tiny/asr1/README.md index cfa26670..489f5bc3 100644 --- a/examples/tiny/asr1/README.md +++ b/examples/tiny/asr1/README.md @@ -37,7 +37,7 @@ It will support the way of using `--variable value` in the shell scripts. Some local variables are set in `run.sh`. `gpus` denotes the GPU number you want to use. If you set `gpus=`, it means you only use CPU. `stage` denotes the number of stage you want the start from in the experiments. -`stop stage` denotes the number of stage you want the stop at in the expriments. +`stop stage` denotes the number of stage you want the stop at in the experiments. `conf_path` denotes the config path of the model. `avg_num`denotes the number K of top-K models you want to average to get the final model. `ckpt` denotes the checkpoint prefix of the model, e.g.
"transformerr" diff --git a/examples/vctk/vc3/conf/default.yaml b/examples/vctk/vc3/conf/default.yaml index 0acc2a56..eb98515a 100644 --- a/examples/vctk/vc3/conf/default.yaml +++ b/examples/vctk/vc3/conf/default.yaml @@ -1,22 +1,135 @@ - generator_params: +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +# 源码 load 的时候用的 24k, 提取 mel 用的 16k, 后续 load 和提取 mel 都要改成 24k +fs: 16000 +n_fft: 2048 +n_shift: 300 +win_length: 1200 # Window length.(in samples) 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. + +fmin: 0 # Minimum frequency of Mel basis. +fmax: 8000 # Maximum frequency of Mel basis. sr // 2 +n_mels: 80 +# only for StarGANv2 VC +norm: # None here +htk: True +power: 2.0 + + +########################################################### +# MODEL SETTING # +########################################################### +generator_params: dim_in: 64 style_dim: 64 max_conv_dim: 512 w_hpf: 0 F0_channel: 256 - mapping_network_params: +mapping_network_params: num_domains: 20 # num of speakers in StarGANv2 latent_dim: 16 style_dim: 64 # same as style_dim in generator_params hidden_dim: 512 # same as max_conv_dim in generator_params - style_encoder_params: +style_encoder_params: dim_in: 64 # same as dim_in in generator_params style_dim: 64 # same as style_dim in generator_params num_domains: 20 # same as num_domains in generator_params max_conv_dim: 512 # same as max_conv_dim in generator_params - discriminator_params: +discriminator_params: dim_in: 64 # same as dim_in in generator_params num_domains: 20 # same as num_domains in mapping_network_params max_conv_dim: 512 # same as max_conv_dim in generator_params - n_repeat: 4 - \ No newline at end of file + repeat_num: 4 +asr_params: + input_dim: 80 + hidden_dim: 256 + n_token: 80 + token_embedding_dim: 256 + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +loss_params: + g_loss: + lambda_sty: 1. + lambda_cyc: 5. + lambda_ds: 1. + lambda_norm: 1. + lambda_asr: 10. + lambda_f0: 5. + lambda_f0_sty: 0.1 + lambda_adv: 2. + lambda_adv_cls: 0.5 + norm_bias: 0.5 + d_loss: + lambda_reg: 1. + lambda_adv_cls: 0.1 + lambda_con_reg: 10. + + adv_cls_epoch: 50 + con_reg_epoch: 30 + + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 5 # Batch size. +num_workers: 2 # Number of workers in DataLoader. 
diff --git a/examples/vctk/vc3/local/preprocess.sh b/examples/vctk/vc3/local/preprocess.sh
index ea0fbc43..058171c5 100755
--- a/examples/vctk/vc3/local/preprocess.sh
+++ b/examples/vctk/vc3/local/preprocess.sh
@@ -6,13 +6,32 @@ stop_stage=100
 config_path=$1
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # extract features
+    echo "Extract features ..."
+    python3 ${BIN_DIR}/preprocess.py \
+        --dataset=vctk \
+        --rootdir=~/datasets/VCTK-Corpus-0.92/ \
+        --dumpdir=dump \
+        --config=${config_path} \
+        --num-cpu=20
 fi
 
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-
-fi
-
-if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "Normalize ..."
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --dumpdir=dump/train/norm \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/dev/raw/metadata.jsonl \
+        --dumpdir=dump/dev/norm \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/test/raw/metadata.jsonl \
+        --dumpdir=dump/test/norm \
+        --speaker-dict=dump/speaker_id_map.txt
 fi
diff --git a/examples/vctk/vc3/local/train.sh b/examples/vctk/vc3/local/train.sh
index 3a507650..d4ea02da 100755
--- a/examples/vctk/vc3/local/train.sh
+++ b/examples/vctk/vc3/local/train.sh
@@ -9,5 +9,4 @@ python3 ${BIN_DIR}/train.py \
     --config=${config_path} \
     --output-dir=${train_output_path} \
     --ngpu=1 \
-    --phones-dict=dump/phone_id_map.txt \
     --speaker-dict=dump/speaker_id_map.txt
diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py
index 6c7e75c1..969d189f 100644
--- a/paddlespeech/__init__.py
+++ b/paddlespeech/__init__.py
@@ -13,3 +13,7 @@
 # limitations under the License.
import _locale _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) + +__version__ = '0.0.0' + +__commit__ = '9cf8c1985a98bb380c183116123672976bdfe5c9' diff --git a/paddlespeech/cli/download.py b/paddlespeech/cli/download.py index 5661f18f..e77a05d2 100644 --- a/paddlespeech/cli/download.py +++ b/paddlespeech/cli/download.py @@ -133,10 +133,10 @@ def _get_download(url, fullname): total_size = req.headers.get('content-length') with open(tmp_fullname, 'wb') as f: if total_size: - with tqdm(total=(int(total_size) + 1023) // 1024) as pbar: + with tqdm(total=(int(total_size)), unit='B', unit_scale=True) as pbar: for chunk in req.iter_content(chunk_size=1024): f.write(chunk) - pbar.update(1) + pbar.update(len(chunk)) else: for chunk in req.iter_content(chunk_size=1024): if chunk: diff --git a/speechx/docker/.gitkeep b/paddlespeech/dataset/__init__.py similarity index 100% rename from speechx/docker/.gitkeep rename to paddlespeech/dataset/__init__.py diff --git a/dataset/aidatatang_200zh/README.md b/paddlespeech/dataset/aidatatang_200zh/README.md similarity index 100% rename from dataset/aidatatang_200zh/README.md rename to paddlespeech/dataset/aidatatang_200zh/README.md diff --git a/paddlespeech/dataset/aidatatang_200zh/__init__.py b/paddlespeech/dataset/aidatatang_200zh/__init__.py new file mode 100644 index 00000000..9146247d --- /dev/null +++ b/paddlespeech/dataset/aidatatang_200zh/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .aidatatang_200zh import main as aidatatang_200zh_main diff --git a/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py b/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py new file mode 100644 index 00000000..5d914a43 --- /dev/null +++ b/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py @@ -0,0 +1,158 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prepare aidatatang_200zh mandarin dataset + +Download, unpack and create manifest files. +Manifest file is a json-format file with each line containing the +meta data (i.e. audio filepath, transcript and audio duration) +of each audio file in the data set. 
+"""
+import argparse
+import codecs
+import json
+import os
+from pathlib import Path
+
+import soundfile
+
+from paddlespeech.dataset.download import download
+from paddlespeech.dataset.download import unpack
+from paddlespeech.utils.argparse import print_arguments
+
+DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
+
+URL_ROOT = 'http://www.openslr.org/resources/62'
+# URL_ROOT = 'https://openslr.magicdatatech.com/resources/62'
+DATA_URL = URL_ROOT + '/aidatatang_200zh.tgz'
+MD5_DATA = '6e0f4f39cd5f667a7ee53c397c8d0949'
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument(
+    "--target_dir",
+    default=DATA_HOME + "/aidatatang_200zh",
+    type=str,
+    help="Directory to save the dataset. (default: %(default)s)")
+parser.add_argument(
+    "--manifest_prefix",
+    default="manifest",
+    type=str,
+    help="Filepath prefix for output manifests. (default: %(default)s)")
+args = parser.parse_args()
+
+
+def create_manifest(data_dir, manifest_path_prefix):
+    print("Creating manifest %s ..." % manifest_path_prefix)
+    json_lines = []
+    transcript_path = os.path.join(data_dir, 'transcript',
+                                   'aidatatang_200_zh_transcript.txt')
+    transcript_dict = {}
+    for line in codecs.open(transcript_path, 'r', 'utf-8'):
+        line = line.strip()
+        if line == '':
+            continue
+        audio_id, text = line.split(' ', 1)
+        # remove whitespace to get character-level text
+        text = ''.join(text.split())
+        transcript_dict[audio_id] = text
+
+    data_types = ['train', 'dev', 'test']
+    for dtype in data_types:
+        del json_lines[:]
+        total_sec = 0.0
+        total_text = 0.0
+        total_num = 0
+
+        audio_dir = os.path.join(data_dir, 'corpus/', dtype)
+        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
+            for fname in filelist:
+                if not fname.endswith('.wav'):
+                    continue
+
+                audio_path = os.path.abspath(os.path.join(subfolder, fname))
+                audio_id = os.path.basename(fname)[:-4]
+                utt2spk = Path(audio_path).parent.name
+
+                audio_data, samplerate = soundfile.read(audio_path)
+                duration = float(len(audio_data) / samplerate)
+                text = transcript_dict[audio_id]
+                json_lines.append(
+                    json.dumps(
+                        {
+                            'utt': audio_id,
+                            'utt2spk': str(utt2spk),
+                            'feat': audio_path,
+                            'feat_shape': (duration, ),  # second
+                            'text': text,
+                        },
+                        ensure_ascii=False))
+
+                total_sec += duration
+                total_text += len(text)
+                total_num += 1
+
+        manifest_path = manifest_path_prefix + '.' + dtype
+        with codecs.open(manifest_path, 'w', 'utf-8') as fout:
+            for line in json_lines:
+                fout.write(line + '\n')
+
+        manifest_dir = os.path.dirname(manifest_path_prefix)
+        meta_path = os.path.join(manifest_dir, dtype) + '.meta'
+        with open(meta_path, 'w') as f:
+            print(f"{dtype}:", file=f)
+            print(f"{total_num} utts", file=f)
+            print(f"{total_sec / (60*60)} h", file=f)
+            print(f"{total_text} text", file=f)
+            print(f"{total_text / total_sec} text/sec", file=f)
+            print(f"{total_sec / total_num} sec/utt", file=f)
+
+
+def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):
+    """Download, unpack and create manifest file."""
+    data_dir = os.path.join(target_dir, subset)
+    if not os.path.exists(data_dir):
+        filepath = download(url, md5sum, target_dir)
+        unpack(filepath, target_dir)
+        # unpack all audio tar files
+        audio_dir = os.path.join(data_dir, 'corpus')
+        for subfolder, dirlist, filelist in sorted(os.walk(audio_dir)):
+            for sub in dirlist:
+                print(f"unpack dir {sub}...")
+                for folder, _, filelist in sorted(
+                        os.walk(os.path.join(subfolder, sub))):
+                    for ftar in filelist:
+                        unpack(os.path.join(folder, ftar), folder, True)
+    else:
+        print("Skip downloading and unpacking. Data already exists in %s." %
+              target_dir)
+
+    create_manifest(data_dir, manifest_path)
+
+
+def main():
+    print_arguments(args, globals())
+    if args.target_dir.startswith('~'):
+        args.target_dir = os.path.expanduser(args.target_dir)
+
+    prepare_dataset(
+        url=DATA_URL,
+        md5sum=MD5_DATA,
+        target_dir=args.target_dir,
+        manifest_path=args.manifest_prefix,
+        subset='aidatatang_200zh')
+
+    print("Data download and manifest prepare done!")
+
+
+if __name__ == '__main__':
+    main()
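Since the script now lives inside the `paddlespeech.dataset` package, the preparation pipeline is importable as well as runnable from the CLI. A minimal sketch (the target directory and manifest prefix are hypothetical; note that the module parses CLI arguments at import time, so run it without unrelated flags, and the first run downloads a multi-GB archive):

```python
import os

from paddlespeech.dataset.aidatatang_200zh.aidatatang_200zh import DATA_URL
from paddlespeech.dataset.aidatatang_200zh.aidatatang_200zh import MD5_DATA
from paddlespeech.dataset.aidatatang_200zh.aidatatang_200zh import prepare_dataset

# Downloads and unpacks the corpus under <target_dir>/aidatatang_200zh, then
# writes manifest.{train,dev,test} next to the given prefix (the data/
# directory must already exist).
prepare_dataset(
    url=DATA_URL,
    md5sum=MD5_DATA,
    target_dir=os.path.expanduser(
        '~/.cache/paddle/dataset/speech/aidatatang_200zh'),
    manifest_path='data/manifest',
    subset='aidatatang_200zh')
```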
diff --git a/paddlespeech/dataset/aishell/README.md b/paddlespeech/dataset/aishell/README.md
new file mode 100644
index 00000000..c46312df
--- /dev/null
+++ b/paddlespeech/dataset/aishell/README.md
@@ -0,0 +1,58 @@
+# [Aishell1](http://openslr.elda.org/33/)
+
+This Open Source Mandarin Speech Corpus, AISHELL-ASR0009-OS1, is 178 hours long. It is a part of AISHELL-ASR0009, whose utterances cover 11 domains, including smart home, autonomous driving, and industrial production. All recordings were made in a quiet indoor environment, using 3 different devices at the same time: a high-fidelity microphone (44.1 kHz, 16-bit), an Android-system mobile phone (16 kHz, 16-bit), and an iOS-system mobile phone (16 kHz, 16-bit). The high-fidelity audio was re-sampled to 16 kHz to build AISHELL-ASR0009-OS1. 400 speakers from different accent areas in China were invited to participate in the recording. The manual transcription accuracy is above 95%, achieved through professional speech annotation and strict quality inspection. The corpus is divided into training, development and testing sets. (This database is free for academic research; permission is required for commercial use.)
+
+
+## Dataset Architecture
+
+```bash
+data_aishell
+├── transcript      # transcript dir
+└── wav             # wav dir
+    ├── dev         # dev dir
+    │   ├── S0724   # speaker dir
+    │   ├── S0725
+    │   ├── S0726
+    ├── train
+    │   ├── S0724
+    │   ├── S0725
+    │   ├── S0726
+    ├── test
+    │   ├── S0724
+    │   ├── S0725
+    │   ├── S0726
+
+
+data_aishell
+├── transcript
+│   └── aishell_transcript_v0.8.txt # transcript file
+└── wav
+    ├── dev
+    │   ├── S0724
+    │   │   ├── BAC009S0724W0121.wav # audio of speaker S0724
+    │   │   ├── BAC009S0724W0122.wav
+    │   │   ├── BAC009S0724W0123.wav
+    ├── test
+    │   ├── S0724
+    │   │   ├── BAC009S0724W0121.wav
+    │   │   ├── BAC009S0724W0122.wav
+    │   │   ├── BAC009S0724W0123.wav
+    ├── train
+    │   ├── S0724
+    │   │   ├── BAC009S0724W0121.wav
+    │   │   ├── BAC009S0724W0122.wav
+    │   │   ├── BAC009S0724W0123.wav
+
+Transcript file format:
+> head data_aishell/transcript/aishell_transcript_v0.8.txt
+BAC009S0002W0122 而 对 楼市 成交 抑制 作用 最 大 的 限 购
+BAC009S0002W0123 也 成为 地方 政府 的 眼中 钉
+BAC009S0002W0124 自 六月 底 呼和浩特 市 率先 宣布 取消 限 购 后
+BAC009S0002W0125 各地 政府 便 纷纷 跟进
+BAC009S0002W0126 仅 一 个 多 月 的 时间 里
+BAC009S0002W0127 除了 北京 上海 广州 深圳 四 个 一 线 城市 和 三亚 之外
+BAC009S0002W0128 四十六 个 限 购 城市 当中
+BAC009S0002W0129 四十一 个 已 正式 取消 或 变相 放松 了 限 购
+BAC009S0002W0130 财政 金融 政策 紧随 其后 而来
+BAC009S0002W0131 显示 出 了 极 强 的 威力
+```
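The transcript file maps an utterance ID to a space-separated character transcription. A small sketch of parsing it into a dict, mirroring what `aishell.py` below does (the local path is hypothetical):

```python
import codecs

# Hypothetical local path to the unpacked corpus
transcript_path = 'data_aishell/transcript/aishell_transcript_v0.8.txt'

transcript_dict = {}
for line in codecs.open(transcript_path, 'r', 'utf-8'):
    line = line.strip()
    if not line:
        continue
    # "BAC009S0002W0122 而 对 楼市 ..." -> utterance id + text
    audio_id, text = line.split(' ', 1)
    transcript_dict[audio_id] = ''.join(text.split())  # drop spaces
```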
diff --git a/paddlespeech/dataset/aishell/__init__.py b/paddlespeech/dataset/aishell/__init__.py
new file mode 100644
index 00000000..667680af
--- /dev/null
+++ b/paddlespeech/dataset/aishell/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .aishell import check_dataset
+from .aishell import create_manifest
+from .aishell import download_dataset
+from .aishell import main as aishell_main
+from .aishell import prepare_dataset
diff --git a/paddlespeech/dataset/aishell/aishell.py b/paddlespeech/dataset/aishell/aishell.py
new file mode 100644
index 00000000..7ea4d676
--- /dev/null
+++ b/paddlespeech/dataset/aishell/aishell.py
@@ -0,0 +1,230 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Prepare Aishell Mandarin dataset
+
+Download, unpack and create manifest files.
+Manifest file is a json-format file with each line containing the
+meta data (i.e. audio filepath, transcript and audio duration)
+of each audio file in the data set.
+"""
+import argparse
+import codecs
+import json
+import os
+from pathlib import Path
+
+import soundfile
+
+from paddlespeech.dataset.download import download
+from paddlespeech.dataset.download import unpack
+from paddlespeech.utils.argparse import print_arguments
+
+DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
+
+URL_ROOT = 'http://openslr.elda.org/resources/33'
+# URL_ROOT = 'https://openslr.magicdatatech.com/resources/33'
+DATA_URL = URL_ROOT + '/data_aishell.tgz'
+MD5_DATA = '2f494334227864a8a8fec932999db9d8'
+RESOURCE_URL = URL_ROOT + '/resource_aishell.tgz'
+MD5_RESOURCE = '957d480a0fcac85fc18e550756f624e5'
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument(
+    "--target_dir",
+    default=DATA_HOME + "/Aishell",
+    type=str,
+    help="Directory to save the dataset. (default: %(default)s)")
+parser.add_argument(
+    "--manifest_prefix",
+    default="manifest",
+    type=str,
+    help="Filepath prefix for output manifests. (default: %(default)s)")
+args = parser.parse_args()
+
+
+def create_manifest(data_dir, manifest_path_prefix):
+    print("Creating manifest %s ..." % os.path.join(data_dir,
+                                                    manifest_path_prefix))
+    json_lines = []
+    transcript_path = os.path.join(data_dir, 'transcript',
+                                   'aishell_transcript_v0.8.txt')
+    transcript_dict = {}
+    for line in codecs.open(transcript_path, 'r', 'utf-8'):
+        line = line.strip()
+        if line == '':
+            continue
+        audio_id, text = line.split(' ', 1)
+        # remove whitespace to get character-level text
+        text = ''.join(text.split())
+        transcript_dict[audio_id] = text
+
+    data_metas = dict()
+    data_types = ['train', 'dev', 'test']
+    for dtype in data_types:
+        del json_lines[:]
+        total_sec = 0.0
+        total_text = 0.0
+        total_num = 0
+
+        audio_dir = os.path.join(data_dir, 'wav', dtype)
+        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
+            for fname in filelist:
+                audio_path = os.path.abspath(os.path.join(subfolder, fname))
+                audio_id = os.path.basename(fname)[:-4]
+                # skip audio without transcription
+                if audio_id not in transcript_dict:
+                    continue
+
+                utt2spk = Path(audio_path).parent.name
+                audio_data, samplerate = soundfile.read(audio_path)
+                duration = float(len(audio_data) / samplerate)
+                text = transcript_dict[audio_id]
+                json_lines.append(
+                    json.dumps(
+                        {
+                            'utt': audio_id,
+                            'utt2spk': str(utt2spk),
+                            'feat': audio_path,
+                            'feat_shape': (duration, ),  # second
+                            'text': text
+                        },
+                        ensure_ascii=False))
+
+                total_sec += duration
+                total_text += len(text)
+                total_num += 1
+
+        manifest_path = manifest_path_prefix + '.' + dtype
+        with codecs.open(manifest_path, 'w', 'utf-8') as fout:
+            for line in json_lines:
+                fout.write(line + '\n')
+
+        meta = dict()
+        meta["dtype"] = dtype  # train, dev, test
+        meta["utts"] = total_num
+        meta["hours"] = total_sec / (60 * 60)
+        meta["text"] = total_text
+        meta["text/sec"] = total_text / total_sec
+        meta["sec/utt"] = total_sec / total_num
+        data_metas[dtype] = meta
+
+        manifest_dir = os.path.dirname(manifest_path_prefix)
+        meta_path = os.path.join(manifest_dir, dtype) + '.meta'
+        with open(meta_path, 'w') as f:
+            for key, val in meta.items():
+                print(f"{key}: {val}", file=f)
+
+    return data_metas
+
+
+def download_dataset(url, md5sum, target_dir):
+    """Download and unpack the dataset."""
+    data_dir = os.path.join(target_dir, 'data_aishell')
+    if not os.path.exists(data_dir):
+        filepath = download(url, md5sum, target_dir)
+        unpack(filepath, target_dir)
+        # unpack all audio tar files
+        audio_dir = os.path.join(data_dir, 'wav')
+        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
+            for ftar in filelist:
+                unpack(os.path.join(subfolder, ftar), subfolder, True)
+    else:
+        print("Skip downloading and unpacking. Data already exists in %s." %
+              os.path.abspath(target_dir))
+    return os.path.abspath(data_dir)
+
+
+def check_dataset(data_dir):
+    print(f"check dataset {os.path.abspath(data_dir)} ...")
+
+    transcript_path = os.path.join(data_dir, 'transcript',
+                                   'aishell_transcript_v0.8.txt')
+    if not os.path.exists(transcript_path):
+        raise FileNotFoundError(f"no transcript file found in {data_dir}.")
+
+    transcript_dict = {}
+    for line in codecs.open(transcript_path, 'r', 'utf-8'):
+        line = line.strip()
+        if line == '':
+            continue
+        audio_id, text = line.split(' ', 1)
+        # remove whitespace to get character-level text
+        text = ''.join(text.split())
+        transcript_dict[audio_id] = text
+
+    no_label = 0
+    data_types = ['train', 'dev', 'test']
+    for dtype in data_types:
+        audio_dir = os.path.join(data_dir, 'wav', dtype)
+        if not os.path.exists(audio_dir):
+            raise IOError(f"{audio_dir} does not exist.")
+
+        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
+            for fname in filelist:
+                audio_path = os.path.abspath(os.path.join(subfolder, fname))
+                audio_id = os.path.basename(fname)[:-4]
+                # skip audio without transcription
+                if audio_id not in transcript_dict:
+                    print(f"Warning: {audio_id} has no transcript.")
+                    no_label += 1
+                    continue
+
+                utt2spk = Path(audio_path).parent.name
+                audio_data, samplerate = soundfile.read(audio_path)
+                assert samplerate == 16000, f"{audio_path} sample rate is {samplerate}, not 16k, please check."
+
+    print(f"Warning: {dtype} has {no_label} audio files without transcripts.")
+
+
+def prepare_dataset(url, md5sum, target_dir, manifest_path=None, check=False):
+    """Download, unpack and create manifest file."""
+    data_dir = download_dataset(url, md5sum, target_dir)
+
+    if check:
+        try:
+            check_dataset(data_dir)
+        except Exception as e:
+            raise ValueError(
+                f"{data_dir} dataset format is not right, please check it.")
+
+    meta = None
+    if manifest_path:
+        meta = create_manifest(data_dir, manifest_path)
+
+    return data_dir, meta
+
+
+def main():
+    print_arguments(args, globals())
+    if args.target_dir.startswith('~'):
+        args.target_dir = os.path.expanduser(args.target_dir)
+
+    data_dir, meta = prepare_dataset(
+        url=DATA_URL,
+        md5sum=MD5_DATA,
+        target_dir=args.target_dir,
+        manifest_path=args.manifest_prefix,
+        check=True)
+
+    resource_dir, _ = prepare_dataset(
+        url=RESOURCE_URL,
+        md5sum=MD5_RESOURCE,
+        target_dir=args.target_dir,
+        manifest_path=None)
+
+    print("Data download and manifest prepare done!")
+
+
+if __name__ == '__main__':
+    main()
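The same download, check, and manifest pipeline is importable for AISHELL-1. A minimal sketch (hypothetical paths; importing the module parses CLI arguments at import time, and the first run downloads a multi-GB archive):

```python
import os

from paddlespeech.dataset.aishell import prepare_dataset
from paddlespeech.dataset.aishell.aishell import DATA_URL
from paddlespeech.dataset.aishell.aishell import MD5_DATA

# Returns the unpacked data dir and, because a manifest prefix is given,
# the per-split meta (utts, hours, text/sec, sec/utt).
data_dir, metas = prepare_dataset(
    url=DATA_URL,
    md5sum=MD5_DATA,
    target_dir=os.path.expanduser('~/.cache/paddle/dataset/speech/Aishell'),
    manifest_path='data/manifest',
    check=True)
print(metas['train']['hours'])
```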
diff --git a/utils/utility.py b/paddlespeech/dataset/download.py
similarity index 59%
rename from utils/utility.py
rename to paddlespeech/dataset/download.py
index dbf8b1d7..28dbd0eb 100755
--- a/utils/utility.py
+++ b/paddlespeech/dataset/download.py
@@ -19,91 +19,16 @@ import zipfile
 from typing import Text
 
 __all__ = [
-    "check_md5sum", "getfile_insensitive", "download_multi", "download",
-    "unpack", "unzip", "md5file", "print_arguments", "add_arguments",
-    "get_commandline_args"
+    "check_md5sum",
+    "getfile_insensitive",
+    "download_multi",
+    "download",
+    "unpack",
+    "unzip",
+    "md5file",
 ]
 
 
-def get_commandline_args():
-    extra_chars = [
-        " ",
-        ";",
-        "&",
-        "(",
-        ")",
-        "|",
-        "^",
-        "<",
-        ">",
-        "?",
-        "*",
-        "[",
-        "]",
-        "$",
-        "`",
-        '"',
-        "\\",
-        "!",
-        "{",
-        "}",
-    ]
-
-    # Escape the extra characters for shell
-    argv = [
-        arg.replace("'", "'\\''") if all(char not in arg
-                                         for char in extra_chars) else
-        "'" + arg.replace("'", "'\\''") + "'" for arg in sys.argv
-    ]
-
-    return sys.executable + " " + " ".join(argv)
-
-
-def print_arguments(args, info=None):
-    """Print argparse's arguments.
-
-    Usage:
-
-    .. code-block:: python
-
-        parser = argparse.ArgumentParser()
-        parser.add_argument("name", default="Jonh", type=str, help="User name.")
-        args = parser.parse_args()
-        print_arguments(args)
-
-    :param args: Input argparse.Namespace for printing.
-    :type args: argparse.Namespace
-    """
-    filename = ""
-    if info:
-        filename = info["__file__"]
-        filename = os.path.basename(filename)
-    print(f"----------- {filename} Configuration Arguments -----------")
-    for arg, value in sorted(vars(args).items()):
-        print("%s: %s" % (arg, value))
-    print("-----------------------------------------------------------")
-
-
-def add_arguments(argname, type, default, help, argparser, **kwargs):
-    """Add argparse's argument.
-
-    Usage:
-
-    .. code-block:: python
-
-        parser = argparse.ArgumentParser()
-        add_argument("name", str, "Jonh", "User name.", parser)
-        args = parser.parse_args()
-    """
-    type = distutils.util.strtobool if type == bool else type
-    argparser.add_argument(
-        "--" + argname,
-        default=default,
-        type=type,
-        help=help + ' Default: %(default)s.',
-        **kwargs)
-
-
 def md5file(fname):
     hash_md5 = hashlib.md5()
     f = open(fname, "rb")
diff --git a/paddlespeech/dataset/s2t/__init__.py b/paddlespeech/dataset/s2t/__init__.py
new file mode 100644
index 00000000..27ea9e77
--- /dev/null
+++ b/paddlespeech/dataset/s2t/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# s2t utils binaries.
+from .avg_model import main as avg_ckpts_main
+from .build_vocab import main as build_vocab_main
+from .compute_mean_std import main as compute_mean_std_main
+from .compute_wer import main as compute_wer_main
+from .format_data import main as format_data_main
+from .format_rsl import main as format_rsl_main
diff --git a/paddlespeech/dataset/s2t/avg_model.py b/paddlespeech/dataset/s2t/avg_model.py
new file mode 100755
index 00000000..c5753b72
--- /dev/null
+++ b/paddlespeech/dataset/s2t/avg_model.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import glob
+import json
+import os
+
+import numpy as np
+import paddle
+
+
+def define_argparse():
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='path of the averaged model to write')
+    parser.add_argument(
+        '--ckpt_dir', required=True, help='ckpt model dir for average')
+    parser.add_argument(
+        '--val_best', action="store_true", help='average the checkpoints with the best validation loss')
+    parser.add_argument(
+        '--num', default=5, type=int, help='number of checkpoints to average')
+    parser.add_argument(
+        '--min_epoch',
+        default=0,
+        type=int,
+        help='min epoch used for averaging model')
+    parser.add_argument(
+        '--max_epoch',
+        default=65536,  # Big enough
+        type=int,
+        help='max epoch used for averaging model')
+
+    args = parser.parse_args()
+    return args
+
+
+def average_checkpoints(dst_model="",
+                        ckpt_dir="",
+                        val_best=True,
+                        num=5,
+                        min_epoch=0,
+                        max_epoch=65536):
+    paddle.set_device('cpu')
+
+    val_scores = []
+    jsons = glob.glob(f'{ckpt_dir}/[!train]*.json')
+    jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
+    for y in jsons:
+        with open(y, 'r') as f:
+            dic_json = json.load(f)
+        loss = dic_json['val_loss']
+        epoch = dic_json['epoch']
+        if epoch >= min_epoch and epoch <= max_epoch:
+            val_scores.append((epoch, loss))
+    assert val_scores, f"Did not find any valid checkpoints in: {ckpt_dir}"
+    val_scores = np.array(val_scores)
+
+    if val_best:
+        sort_idx = np.argsort(val_scores[:, 1])
+        sorted_val_scores = val_scores[sort_idx]
+    else:
+        sorted_val_scores = val_scores
+
+    best_val_scores = sorted_val_scores[:num, 1]
+    selected_epochs = sorted_val_scores[:num, 0].astype(np.int64)
+    avg_val_score = np.mean(best_val_scores)
+    print("selected val scores = " + str(best_val_scores))
+    print("selected epochs = " + str(selected_epochs))
+    print("averaged val score = " + str(avg_val_score))
+
+    path_list = [
+        ckpt_dir + '/{}.pdparams'.format(int(epoch))
+        for epoch in sorted_val_scores[:num, 0]
+    ]
+    print(path_list)
+
+    avg = None
+    # all `num` selected checkpoints must be present
+    assert num == len(path_list)
+    for path in path_list:
+        print(f'Processing {path}')
+        states = paddle.load(path)
+        if avg is None:
+            avg = states
+        else:
+            for k in avg.keys():
+                avg[k] += states[k]
+    # average
+    for k in avg.keys():
+        if avg[k] is not None:
+            avg[k] /= num
+
+    paddle.save(avg, dst_model)
+    print(f'Saving to {dst_model}')
+
+    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
+    with open(meta_path, 'w') as f:
+        data = json.dumps({
+            "mode": 'val_best' if val_best else 'latest',
+            "avg_ckpt": dst_model,
+            "val_loss_mean": avg_val_score,
+            "ckpts": path_list,
+            "epochs": selected_epochs.tolist(),
+            "val_losses": best_val_scores.tolist(),
+        })
+        f.write(data + "\n")
+
+
+def main():
+    args = define_argparse()
+    average_checkpoints(**vars(args))
+
+
+if __name__ == '__main__':
+    main()
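Besides the `avg_ckpts_main` CLI entry exported from `paddlespeech.dataset.s2t`, the averaging logic is callable directly. A minimal sketch, assuming a typical experiment layout with `<epoch>.pdparams` / `<epoch>.json` pairs in the checkpoint dir (the paths are hypothetical):

```python
from paddlespeech.dataset.s2t.avg_model import average_checkpoints

# Averages the 5 checkpoints with the lowest validation loss and writes
# avg_5.pdparams plus an avg_5.avg.json sidecar describing the selection.
average_checkpoints(
    dst_model='exp/conformer/checkpoints/avg_5.pdparams',
    ckpt_dir='exp/conformer/checkpoints',
    val_best=True,
    num=5)
```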
diff --git a/paddlespeech/dataset/s2t/build_vocab.py b/paddlespeech/dataset/s2t/build_vocab.py
new file mode 100755
index 00000000..dd5f6208
--- /dev/null
+++ b/paddlespeech/dataset/s2t/build_vocab.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Build vocabulary from manifest files.
+Each item in vocabulary file is a character.
+"""
+import argparse
+import functools
+import os
+import tempfile
+from collections import Counter
+
+import jsonlines
+
+from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.s2t.frontend.utility import BLANK
+from paddlespeech.s2t.frontend.utility import SOS
+from paddlespeech.s2t.frontend.utility import SPACE
+from paddlespeech.s2t.frontend.utility import UNK
+from paddlespeech.utils.argparse import add_arguments
+from paddlespeech.utils.argparse import print_arguments
+
+
+def count_manifest(counter, text_feature, manifest_path):
+    manifest_jsons = []
+    with jsonlines.open(manifest_path, 'r') as reader:
+        for json_data in reader:
+            manifest_jsons.append(json_data)
+
+    for line_json in manifest_jsons:
+        if isinstance(line_json['text'], str):
+            tokens = text_feature.tokenize(
+                line_json['text'], replace_space=False)
+
+            counter.update(tokens)
+        else:
+            assert isinstance(line_json['text'], list)
+            for text in line_json['text']:
+                tokens = text_feature.tokenize(text, replace_space=False)
+                counter.update(tokens)
+
+
+def dump_text_manifest(fileobj, manifest_path, key='text'):
+    manifest_jsons = []
+    with jsonlines.open(manifest_path, 'r') as reader:
+        for json_data in reader:
+            manifest_jsons.append(json_data)
+
+    for line_json in manifest_jsons:
+        if isinstance(line_json[key], str):
+            fileobj.write(line_json[key] + "\n")
+        else:
+            assert isinstance(line_json[key], list)
+            for line in line_json[key]:
+                fileobj.write(line + "\n")
+
+
+def build_vocab(manifest_paths="",
+                vocab_path="examples/librispeech/data/vocab.txt",
+                unit_type="char",
+                count_threshold=0,
+                text_keys='text',
+                spm_mode="unigram",
+                spm_vocab_size=0,
+                spm_model_prefix="",
+                spm_character_coverage=0.9995):
+    fout = open(vocab_path, 'w', encoding='utf-8')
+    fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
+    fout.write(UNK + '\n')  # <unk> must be 1
+
+    if unit_type == 'spm':
+        # tools/spm_train --input=$wave_data/lang_char/input.txt
+        # --vocab_size=${nbpe} --model_type=${bpemode}
+        # --model_prefix=${bpemodel} --input_sentence_size=100000000
+        import sentencepiece as spm
+
+        fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
+        for manifest_path in manifest_paths:
+            _text_keys = [text_keys] if type(
+                text_keys) is not list else text_keys
+            for text_key in _text_keys:
+                dump_text_manifest(fp, manifest_path, key=text_key)
+        fp.close()
+        # train
+        spm.SentencePieceTrainer.Train(
+            input=fp.name,
+            vocab_size=spm_vocab_size,
+            model_type=spm_mode,
+            model_prefix=spm_model_prefix,
+            input_sentence_size=100000000,
+            character_coverage=spm_character_coverage)
+        os.unlink(fp.name)
+
+    # encode
+    text_feature = TextFeaturizer(unit_type, "", spm_model_prefix)
+    counter = Counter()
+
+    for manifest_path in manifest_paths:
+        count_manifest(counter, text_feature, manifest_path)
+
+    count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
+    tokens = []
+    for token, count in count_sorted:
+        if count < count_threshold:
+            break
+        # replace space by <space>
+        token = SPACE if token == ' ' else token
+        tokens.append(token)
+
+    tokens = sorted(tokens)
+    for token in tokens:
+        fout.write(token + '\n')
+
+    fout.write(SOS + "\n")  # <sos/eos>
+    fout.close()
+
+
+def define_argparse():
+    parser = argparse.ArgumentParser(description=__doc__)
+    add_arg = functools.partial(add_arguments, argparser=parser)
+
+    # yapf: disable
+    add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
+    add_arg('count_threshold', int, 0,
+            "Truncation threshold for char/word counts. Default 0, no truncation.")
+    add_arg('vocab_path', str,
+            'examples/librispeech/data/vocab.txt',
+            "Filepath to write the vocabulary.")
+    add_arg('manifest_paths', str,
+            None,
+            "Filepaths of manifests for building vocabulary. "
+            "You can provide multiple manifest files.",
+            nargs='+',
+            required=True)
+    add_arg('text_keys', str,
+            'text',
+            "keys of the text in manifest for building vocabulary. "
+            "You can provide multiple keys.",
+            nargs='+')
+    # bpe
+    add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
+    add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, bpe, char, word. Only needed when `unit_type` is spm.")
+    add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix. Only needed when `unit_type` is spm.")
+    add_arg('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols")
+    # yapf: enable
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = define_argparse()
+    print_arguments(args, globals())
+    build_vocab(**vars(args))
+
+
+if __name__ == '__main__':
+    main()
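`build_vocab` is likewise importable. A minimal sketch for a character-unit vocabulary (the manifest path is hypothetical and must point at a prepared manifest):

```python
from paddlespeech.dataset.s2t.build_vocab import build_vocab

# Writes BLANK, UNK, the sorted character tokens, and finally SOS to vocab.txt
build_vocab(
    manifest_paths=['data/manifest.train'],
    vocab_path='data/vocab.txt',
    unit_type='char')
```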
diff --git a/paddlespeech/dataset/s2t/compute_mean_std.py b/paddlespeech/dataset/s2t/compute_mean_std.py
new file mode 100755
index 00000000..8762ee57
--- /dev/null
+++ b/paddlespeech/dataset/s2t/compute_mean_std.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Compute mean and std for feature normalizer, and save to file."""
+import argparse
+import functools
+
+from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline
+from paddlespeech.s2t.frontend.featurizer.audio_featurizer import AudioFeaturizer
+from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer
+from paddlespeech.utils.argparse import add_arguments
+from paddlespeech.utils.argparse import print_arguments
+
+
+def compute_cmvn(manifest_path="data/librispeech/manifest.train",
+                 output_path="data/librispeech/mean_std.npz",
+                 num_samples=2000,
+                 num_workers=0,
+                 spectrum_type="linear",
+                 feat_dim=13,
+                 delta_delta=False,
+                 stride_ms=10,
+                 window_ms=20,
+                 sample_rate=16000,
+                 use_dB_normalization=True,
+                 target_dB=-20):
+
+    augmentation_pipeline = AugmentationPipeline('{}')
+    audio_featurizer = AudioFeaturizer(
+        spectrum_type=spectrum_type,
+        feat_dim=feat_dim,
+        delta_delta=delta_delta,
+        stride_ms=float(stride_ms),
+        window_ms=float(window_ms),
+        n_fft=None,
+        max_freq=None,
+        target_sample_rate=sample_rate,
+        use_dB_normalization=use_dB_normalization,
+        target_dB=target_dB,
+        dither=0.0)
+
+    def augment_and_featurize(audio_segment):
+        augmentation_pipeline.transform_audio(audio_segment)
+        return audio_featurizer.featurize(audio_segment)
+
+    normalizer = FeatureNormalizer(
+        mean_std_filepath=None,
+        manifest_path=manifest_path,
+        featurize_func=augment_and_featurize,
+        num_samples=num_samples,
+        num_workers=num_workers)
+    normalizer.write_to_file(output_path)
+
+
+def define_argparse():
+    parser = argparse.ArgumentParser(description=__doc__)
+    add_arg = functools.partial(add_arguments, argparser=parser)
+
+    # yapf: disable
+    add_arg('manifest_path', str,
+            'data/librispeech/manifest.train',
+            "Filepath of manifest to compute normalizer's mean and stddev.")
+
+    add_arg('output_path', str,
+            'data/librispeech/mean_std.npz',
+            "Filepath to write mean and stddev to (.npz).")
+    add_arg('num_samples', int, 2000, "# of samples for computing statistics.")
+    add_arg('num_workers',
+            default=0,
+            type=int,
+            help='num of subprocess workers for processing')
+
+    add_arg('spectrum_type', str,
+            'linear',
+            "Audio feature type. Options: linear, mfcc, fbank.",
+            choices=['linear', 'mfcc', 'fbank'])
+    add_arg('feat_dim', int, 13, "Audio feature dim.")
+    add_arg('delta_delta', bool, False, "Audio feature with delta delta.")
+    add_arg('stride_ms', int, 10, "stride length in ms.")
+    add_arg('window_ms', int, 20, "window length in ms.")
+    add_arg('sample_rate', int, 16000, "target sample rate.")
+    add_arg('use_dB_normalization', bool, True, "do dB normalization.")
+    add_arg('target_dB', int, -20, "target dB.")
+    # yapf: enable
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = define_argparse()
+    print_arguments(args, globals())
+    compute_cmvn(**vars(args))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/paddlespeech/dataset/s2t/compute_wer.py b/paddlespeech/dataset/s2t/compute_wer.py
new file mode 100755
index 00000000..5711c725
--- /dev/null
+++ b/paddlespeech/dataset/s2t/compute_wer.py
@@ -0,0 +1,558 @@
+# Copyright 2021 Mobvoi Inc. All Rights Reserved.
+# flake8: noqa +import codecs +import re +import sys +import unicodedata + +remove_tag = True +spacelist = [' ', '\t', '\r', '\n'] +puncts = [ + '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』', + '《', '》' +] + + +def characterize(string): + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + #https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. + sep = ' ' + if char == '<': sep = '>' + j = i + 1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c == sep): + break + j += 1 + if j < len(string) and string[j] == '>': + j += 1 + res.append(string[i:j]) + i = j + return res + + +def stripoff_tags(x): + if not x: return '' + chars = [] + i = 0 + T = len(x) + while i < T: + if x[i] == '<': + while i < T and x[i] != '>': + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return ''.join(chars) + + +def normalize(sentence, ignore_words, cs, split=None): + """ sentence, ignore_words are both in unicode + """ + new_sentence = [] + for token in sentence: + x = token + if not cs: + x = x.upper() + if x in ignore_words: + continue + if remove_tag: + x = stripoff_tags(x) + if not x: + continue + if split and x in split: + new_sentence += split[x] + else: + new_sentence.append(x) + return new_sentence + + +class Calculator: + def __init__(self): + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + + def calculate(self, lab, rec): + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab): + self.space.append([]) + for row in self.space: + for element in row: + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec): + row.append({'dist': 0, 'error': 'non'}) + for i in range(len(lab)): + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)): + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + for token in rec: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + # Computing edit distance + for i, lab_token in enumerate(lab): + for j, rec_token in enumerate(rec): + if i == 0 or j == 0: + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i - 1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist: + min_dist = dist + min_error = error + dist = self.space[i][j - 1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist: + min_dist = dist + min_error = error + if lab_token == rec_token: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] + error = 'cor' + else: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist: + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = { + 'lab': [], + 'rec': [], + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + 
i = len(lab) - 1 + j = len(rec) - 1 + while True: + if self.space[i][j]['error'] == 'cor': # correct + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub': # substitution + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del': # deletion + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins': # insertion + if len(rec[j]) > 0: + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, "") + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non': # starting point + break + else: # shouldn't reach here + print( + 'this should not happen , i = {i} , j = {j} , error = {error}'. + format(i=i, j=j, error=self.space[i][j]['error'])) + return result + + def overall(self): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def cluster(self, data): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in data: + if token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def keys(self): + return list(self.data.keys()) + + +def width(string): + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + + +def default_cluster(word): + unicode_names = [unicodedata.name(char) for char in word] + for i in reversed(range(len(unicode_names))): + if unicode_names[i].startswith('DIGIT'): # 1 + unicode_names[i] = 'Number' # 'DIGIT' + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or + unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')): + # 明 / 郎 + unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or + unicode_names[i].startswith('LATIN SMALL LETTER')): + # A / a + unicode_names[i] = 'English' # 'LATIN LETTER' + elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め + unicode_names[i] = 'Japanese' # 'GANA LETTER' + elif (unicode_names[i].startswith('AMPERSAND') or + unicode_names[i].startswith('APOSTROPHE') or + unicode_names[i].startswith('COMMERCIAL AT') or + unicode_names[i].startswith('DEGREE CELSIUS') or + unicode_names[i].startswith('EQUALS SIGN') or 
+ unicode_names[i].startswith('FULL STOP') or + unicode_names[i].startswith('HYPHEN-MINUS') or + unicode_names[i].startswith('LOW LINE') or + unicode_names[i].startswith('NUMBER SIGN') or + unicode_names[i].startswith('PLUS SIGN') or + unicode_names[i].startswith('SEMICOLON')): + # & / ' / @ / ℃ / = / . / - / _ / # / + / ; + del unicode_names[i] + else: + return 'Other' + if len(unicode_names) == 0: + return 'Other' + if len(unicode_names) == 1: + return unicode_names[0] + for i in range(len(unicode_names) - 1): + if unicode_names[i] != unicode_names[i + 1]: + return 'Other' + return unicode_names[0] + + +def usage(): + print( + "compute-wer.py : compute word error rate (WER) and align recognition results and references." + ) + print( + " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer" + ) + + +def main(): + # python utils/compute-wer.py --char=1 --v=1 ref hyp > rsl.error + if len(sys.argv) == 1: + usage() + sys.exit(0) + calculator = Calculator() + cluster_file = '' + ignore_words = set() + tochar = False + verbose = 1 + padding_symbol = ' ' + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + while len(sys.argv) > 3: + a = '--maxw=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):] + del sys.argv[1] + max_words_per_line = int(b) + continue + a = '--rt=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + remove_tag = (b == 'true') or (b != '0') + continue + a = '--cs=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + case_sensitive = (b == 'true') or (b != '0') + continue + a = '--cluster=' + if sys.argv[1].startswith(a): + cluster_file = sys.argv[1][len(a):] + del sys.argv[1] + continue + a = '--splitfile=' + if sys.argv[1].startswith(a): + split_file = sys.argv[1][len(a):] + del sys.argv[1] + split = dict() + with codecs.open(split_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + words = line.strip().split() + if len(words) >= 2: + split[words[0]] = words[1:] + continue + a = '--ig=' + if sys.argv[1].startswith(a): + ignore_file = sys.argv[1][len(a):] + del sys.argv[1] + with codecs.open(ignore_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + line = line.strip() + if len(line) > 0: + ignore_words.add(line) + continue + a = '--char=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + tochar = (b == 'true') or (b != '0') + continue + a = '--v=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + verbose = 0 + try: + verbose = int(b) + except: + if b == 'true' or b != '0': + verbose = 1 + continue + a = '--padding-symbol=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + if b == 'space': + padding_symbol = ' ' + elif b == 'underline': + padding_symbol = '_' + continue + if True or sys.argv[1].startswith('-'): + #ignore invalid switch + del sys.argv[1] + continue + + if not case_sensitive: + ig = set([w.upper() for w in ignore_words]) + ignore_words = ig + + default_clusters = {} + default_words = {} + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + rec_set = {} + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit + + with codecs.open(hyp_file, 'r', 'utf-8') as fh: + for line 
in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, + split) + + # compute error rate on the interaction of reference file and hyp file + for line in open(ref_file, 'r', encoding='utf-8'): + if tochar: + array = characterize(line) + else: + array = line.rstrip('\n').split() + if len(array) == 0: continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + if verbose: + print('\nutt: %s' % fid) + + for word in rec + lab: + if word not in default_words: + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters: + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name]: + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + + result = calculator.calculate(lab, rec) + if verbose: + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('WER: %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])): + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + space['lab'].append(length - len_lab) + space['rec'].append(length - len_rec) + upper_lab = len(result['lab']) + upper_rec = len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end=' ') + else: + print('lab:', end=' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['lab'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end=' ') + else: + print('rec:', end=' ') + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['rec'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 + + if verbose: + print( + '===========================================================================' + ) + print() + + result = calculator.overall() + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('Overall -> %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if not verbose: + print() + + if verbose: + for cluster_id in default_clusters: + result = calculator.cluster( + [k for k in default_clusters[cluster_id]]) + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if len(cluster_file) > 0: # compute separated WERs for word clusters + cluster_id = '' + cluster = [] + for line in 
open(cluster_file, 'r', encoding='utf-8'):
+            for token in line.rstrip('\n').split():
+                # end of cluster reached, like </Keyword>
+                if token[0:2] == '</' and \
+                        token.lstrip('</').rstrip('>') == cluster_id:
+                    result = calculator.cluster(cluster)
+                    if result['all'] != 0:
+                        wer = float(result['ins'] + result['sub'] + result[
+                            'del']) * 100.0 / result['all']
+                    else:
+                        wer = 0.0
+                    print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
+                    print('N=%d C=%d S=%d D=%d I=%d' %
+                          (result['all'], result['cor'], result['sub'],
+                           result['del'], result['ins']))
+                    cluster_id = ''
+                    cluster = []
+                # begin of cluster reached, like <Keyword>
+                elif token[0] == '<' and token[len(token)-1] == '>' and \
+                        cluster_id == '':
+                    cluster_id = token.lstrip('<').rstrip('>')
+                    cluster = []
+                # general terms, like WEATHER / CAR / ...
+                else:
+                    cluster.append(token)
+        print()
+        print(
+            '==========================================================================='
+        )
+
+
+if __name__ == '__main__':
+    main()
diff --git a/paddlespeech/dataset/s2t/format_data.py b/paddlespeech/dataset/s2t/format_data.py
new file mode 100755
index 00000000..dcff66ea
--- /dev/null
+++ b/paddlespeech/dataset/s2t/format_data.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""format manifest with more metadata."""
+import argparse
+import functools
+import json
+
+import jsonlines
+
+from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.s2t.frontend.utility import load_cmvn
+from paddlespeech.s2t.io.utility import feat_type
+from paddlespeech.utils.argparse import add_arguments
+from paddlespeech.utils.argparse import print_arguments
+
+
+def define_argparse():
+    parser = argparse.ArgumentParser(description=__doc__)
+    add_arg = functools.partial(add_arguments, argparser=parser)
+    # yapf: disable
+    add_arg('manifest_paths', str,
+            None,
+            "Filepaths of manifests for building vocabulary. "
+            "You can provide multiple manifest files.",
+            nargs='+',
+            required=True)
+    add_arg('output_path', str, None, "Filepath of formatted manifest.", required=True)
+    add_arg('cmvn_path', str,
+            'examples/librispeech/data/mean_std.json',
+            "Filepath of cmvn.")
+    add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
char, word, spm") + add_arg('vocab_path', str, + 'examples/librispeech/data/vocab.txt', + "Filepath of the vocabulary.") + # bpe + add_arg('spm_model_prefix', str, None, + "spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm") + + # yapf: disable + args = parser.parse_args() + return args + +def format_data( + manifest_paths="", + output_path="", + cmvn_path="examples/librispeech/data/mean_std.json", + unit_type="char", + vocab_path="examples/librispeech/data/vocab.txt", + spm_model_prefix=""): + + fout = open(output_path, 'w', encoding='utf-8') + + # get feat dim + filetype = cmvn_path.split(".")[-1] + mean, istd = load_cmvn(cmvn_path, filetype=filetype) + feat_dim = mean.shape[0] #(D) + print(f"Feature dim: {feat_dim}") + + text_feature = TextFeaturizer(unit_type, vocab_path, spm_model_prefix) + vocab_size = text_feature.vocab_size + print(f"Vocab size: {vocab_size}") + + # josnline like this + # { + # "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}], + # "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}], + # "utt2spk": "111-2222", + # "utt": "111-2222-333" + # } + count = 0 + for manifest_path in manifest_paths: + with jsonlines.open(str(manifest_path), 'r') as reader: + manifest_jsons = list(reader) + + for line_json in manifest_jsons: + output_json = { + "input": [], + "output": [], + 'utt': line_json['utt'], + 'utt2spk': line_json.get('utt2spk', 'global'), + } + + # output + line = line_json['text'] + if isinstance(line, str): + # only one target + tokens = text_feature.tokenize(line) + tokenids = text_feature.featurize(line) + output_json['output'].append({ + 'name': 'target1', + 'shape': (len(tokenids), vocab_size), + 'text': line, + 'token': ' '.join(tokens), + 'tokenid': ' '.join(map(str, tokenids)), + }) + else: + # isinstance(line, list), multi target in one vocab + for i, item in enumerate(line, 1): + tokens = text_feature.tokenize(item) + tokenids = text_feature.featurize(item) + output_json['output'].append({ + 'name': f'target{i}', + 'shape': (len(tokenids), vocab_size), + 'text': item, + 'token': ' '.join(tokens), + 'tokenid': ' '.join(map(str, tokenids)), + }) + + # input + line = line_json['feat'] + if isinstance(line, str): + # only one input + feat_shape = line_json['feat_shape'] + assert isinstance(feat_shape, (list, tuple)), type(feat_shape) + filetype = feat_type(line) + if filetype == 'sound': + feat_shape.append(feat_dim) + else: # kaldi + raise NotImplementedError('no support kaldi feat now!') + + output_json['input'].append({ + "name": "input1", + "shape": feat_shape, + "feat": line, + "filetype": filetype, + }) + else: + # isinstance(line, list), multi input + raise NotImplementedError("not support multi input now!") + + fout.write(json.dumps(output_json) + '\n') + count += 1 + + print(f"{manifest_paths} Examples number: {count}") + fout.close() + +def main(): + args = define_argparse() + print_arguments(args, globals()) + format_data(**vars(args)) + +if __name__ == '__main__': + main() diff --git a/paddlespeech/dataset/s2t/format_rsl.py b/paddlespeech/dataset/s2t/format_rsl.py new file mode 100644 index 00000000..0a58e7e6 --- /dev/null +++ b/paddlespeech/dataset/s2t/format_rsl.py @@ -0,0 +1,143 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
diff --git a/paddlespeech/dataset/s2t/format_rsl.py b/paddlespeech/dataset/s2t/format_rsl.py
new file mode 100644
index 00000000..0a58e7e6
--- /dev/null
+++ b/paddlespeech/dataset/s2t/format_rsl.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+format ref/hyp files into `utt text` format to compute CER/WER/MER.
+
+norm:
+BAC009S0764W0196 明确了发展目标和重点任务
+BAC009S0764W0186 实现我国房地产市场的平稳运行
+
+
+sclite:
+加大对结构机械化环境和收集谈控机制力度(BAC009S0906W0240.wav)
+河南省新乡市丰秋县刘光镇政府东五零左右(BAC009S0770W0441.wav)
+"""
+import argparse
+
+import jsonlines
+
+from paddlespeech.utils.argparse import print_arguments
+
+
+def transform_hyp(origin, trans, trans_sclite):
+ """
+ Args:
+ origin: The input json file which contains the model output
+ trans: The output file for calculating CER/WER
+ trans_sclite: The output file for calculating CER/WER using sclite
+ """
+ input_dict = {}
+
+ with open(origin, "r+", encoding="utf8") as f:
+ for item in jsonlines.Reader(f):
+ input_dict[item["utt"]] = item["hyps"][0]
+
+ if trans:
+ with open(trans, "w+", encoding="utf8") as f:
+ for key in input_dict.keys():
+ f.write(key + " " + input_dict[key] + "\n")
+ print(f"transform_hyp output: {trans}")
+
+ if trans_sclite:
+ with open(trans_sclite, "w+") as f:
+ for key in input_dict.keys():
+ line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
+ f.write(line)
+ print(f"transform_hyp output: {trans_sclite}")
+
+
+def transform_ref(origin, trans, trans_sclite):
+ """
+ Args:
+ origin: The input json file which contains the reference text
+ trans: The output file for calculating CER/WER
+ trans_sclite: The output file for calculating CER/WER using sclite
+ """
+ input_dict = {}
+
+ with open(origin, "r", encoding="utf8") as f:
+ for item in jsonlines.Reader(f):
+ input_dict[item["utt"]] = item["text"]
+
+ if trans:
+ with open(trans, "w", encoding="utf8") as f:
+ for key in input_dict.keys():
+ f.write(key + " " + input_dict[key] + "\n")
+ print(f"transform_ref output: {trans}")
+
+ if trans_sclite:
+ with open(trans_sclite, "w") as f:
+ for key in input_dict.keys():
+ line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
+ f.write(line)
+ print(f"transform_ref output: {trans_sclite}")
+
+
+def define_argparse():
+ parser = argparse.ArgumentParser(
+ prog='format ref/hyp file for computing CER/WER', add_help=True)
+ parser.add_argument(
+ '--origin_hyp', type=str, default="", help='origin hyp file')
+ parser.add_argument(
+ '--trans_hyp',
+ type=str,
+ default="",
+ help='hyp file for calculating CER/WER')
+ parser.add_argument(
+ '--trans_hyp_sclite',
+ type=str,
+ default="",
+ help='hyp file for calculating CER/WER by sclite')
+
+ parser.add_argument(
+ '--origin_ref', type=str, default="", help='origin ref file')
+ parser.add_argument(
+ '--trans_ref',
+ type=str,
+ default="",
+ help='ref file for calculating CER/WER')
+ parser.add_argument(
+ '--trans_ref_sclite',
+ type=str,
+ default="",
+ help='ref file for calculating CER/WER by sclite')
+ parser_args = parser.parse_args()
+ return parser_args
+
+
+def format_result(origin_hyp="",
+ trans_hyp="",
+ trans_hyp_sclite="",
+ origin_ref="",
+ trans_ref="",
+ trans_ref_sclite=""):
+
+ if origin_hyp:
+ transform_hyp(
+ origin=origin_hyp, trans=trans_hyp, trans_sclite=trans_hyp_sclite)
+
+ if origin_ref:
+ transform_ref(
+ origin=origin_ref, trans=trans_ref, trans_sclite=trans_ref_sclite)
+
+
+def main():
+ args = define_argparse()
+ print_arguments(args, globals())
+
+ format_result(**vars(args))
+
+
+if __name__ == "__main__":
+ main()
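A companion sketch for `format_rsl.py`: converting decode results and references into the two scoring formats described in the module docstring. The file names here are made up for illustration:

```python
# Hypothetical invocation of the helpers defined in format_rsl.py above.
from paddlespeech.dataset.s2t.format_rsl import transform_hyp, transform_ref

# `utt text` lines for CER/WER tools, plus `text(utt.wav)` lines for sclite
transform_hyp(
    origin="exp/decode.rsl",      # jsonlines with "utt" and "hyps" fields
    trans="exp/hyp.text",
    trans_sclite="exp/hyp.trn")
transform_ref(
    origin="data/manifest.test",  # jsonlines with "utt" and "text" fields
    trans="exp/ref.text",
    trans_sclite="exp/ref.trn")
```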
diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py
index 6663bcf8..37d99226 100644
--- a/paddlespeech/s2t/__init__.py
+++ b/paddlespeech/s2t/__init__.py
@@ -267,7 +267,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:


if not hasattr(paddle.Tensor, 'to'):
- logger.debug("register user to to paddle.Tensor, remove this when fixed!")
+ logger.debug("register user-defined to() to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'to', to)
setattr(paddle.static.Variable, 'to', to)
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py
index 5755a5f1..f6b1ed09 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py
@@ -28,8 +28,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.socket_server import AsrRequestHandler
from paddlespeech.s2t.utils.socket_server import AsrTCPServer
from paddlespeech.s2t.utils.socket_server import warm_up_test
-from paddlespeech.s2t.utils.utility import add_arguments
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import add_arguments
+from paddlespeech.utils.argparse import print_arguments


def init_predictor(args):
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py
index 0d0b4f21..fc57399d 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py
@@ -26,8 +26,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.socket_server import AsrRequestHandler
from paddlespeech.s2t.utils.socket_server import AsrTCPServer
from paddlespeech.s2t.utils.socket_server import warm_up_test
-from paddlespeech.s2t.utils.utility import add_arguments
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import add_arguments
+from paddlespeech.utils.argparse import print_arguments


def start_server(config, args):
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/export.py b/paddlespeech/s2t/exps/deepspeech2/bin/export.py
index 8acd46df..07228e98 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/export.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/export.py
@@ -16,7 +16,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments


def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test.py b/paddlespeech/s2t/exps/deepspeech2/bin/test.py
index 030168a9..a8e20ff9 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/test.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/test.py
@@ -16,7 +16,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments


def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py
index d7a9402b..1e07aa80 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py index 66ea29d0..32a583b6 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py @@ -27,8 +27,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils import mp_tools from paddlespeech.s2t.utils.checkpoint import Checkpoint from paddlespeech.s2t.utils.log import Log -from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.utils.argparse import print_arguments logger = Log(__name__).getlog() diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/train.py b/paddlespeech/s2t/exps/deepspeech2/bin/train.py index 2c9942f9..1340aaa3 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/train.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/train.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2/bin/alignment.py b/paddlespeech/s2t/exps/u2/bin/alignment.py index e3390feb..cc294038 100644 --- a/paddlespeech/s2t/exps/u2/bin/alignment.py +++ b/paddlespeech/s2t/exps/u2/bin/alignment.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2/bin/export.py b/paddlespeech/s2t/exps/u2/bin/export.py index 592b1237..4725e5e1 100644 --- a/paddlespeech/s2t/exps/u2/bin/export.py +++ b/paddlespeech/s2t/exps/u2/bin/export.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2/bin/test.py b/paddlespeech/s2t/exps/u2/bin/test.py index b13fd0d3..43eeff63 100644 --- a/paddlespeech/s2t/exps/u2/bin/test.py +++ b/paddlespeech/s2t/exps/u2/bin/test.py @@ -18,7 +18,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2/bin/train.py b/paddlespeech/s2t/exps/u2/bin/train.py index dc3a87c1..a0f50328 100644 --- 
a/paddlespeech/s2t/exps/u2/bin/train.py +++ b/paddlespeech/s2t/exps/u2/bin/train.py @@ -19,7 +19,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments # from paddlespeech.s2t.exps.u2.trainer import U2Trainer as Trainer diff --git a/paddlespeech/s2t/exps/u2_kaldi/bin/test.py b/paddlespeech/s2t/exps/u2_kaldi/bin/test.py index 422483b9..4137537e 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/bin/test.py +++ b/paddlespeech/s2t/exps/u2_kaldi/bin/test.py @@ -18,7 +18,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.dynamic_import import dynamic_import -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments model_test_alias = { "u2": "paddlespeech.s2t.exps.u2.model:U2Tester", diff --git a/paddlespeech/s2t/exps/u2_kaldi/bin/train.py b/paddlespeech/s2t/exps/u2_kaldi/bin/train.py index b11da715..011aabac 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/bin/train.py +++ b/paddlespeech/s2t/exps/u2_kaldi/bin/train.py @@ -19,7 +19,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.dynamic_import import dynamic_import -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments model_train_alias = { "u2": "paddlespeech.s2t.exps.u2.model:U2Trainer", diff --git a/paddlespeech/s2t/exps/u2_st/bin/export.py b/paddlespeech/s2t/exps/u2_st/bin/export.py index c641152f..a2a7424c 100644 --- a/paddlespeech/s2t/exps/u2_st/bin/export.py +++ b/paddlespeech/s2t/exps/u2_st/bin/export.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2_st/bin/test.py b/paddlespeech/s2t/exps/u2_st/bin/test.py index c07c95bd..30a903ce 100644 --- a/paddlespeech/s2t/exps/u2_st/bin/test.py +++ b/paddlespeech/s2t/exps/u2_st/bin/test.py @@ -18,7 +18,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2_st/bin/train.py b/paddlespeech/s2t/exps/u2_st/bin/train.py index 574942e5..b36a0af4 100644 --- a/paddlespeech/s2t/exps/u2_st/bin/train.py +++ b/paddlespeech/s2t/exps/u2_st/bin/train.py @@ -19,7 +19,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2_st.model import U2STTrainer as Trainer from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/test.py b/paddlespeech/s2t/exps/wav2vec2/bin/test.py index a376651d..c17cee0f 100644 --- a/paddlespeech/s2t/exps/wav2vec2/bin/test.py +++ b/paddlespeech/s2t/exps/wav2vec2/bin/test.py @@ -18,7 +18,7 @@ 
from yacs.config import CfgNode
from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments


def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/train.py b/paddlespeech/s2t/exps/wav2vec2/bin/train.py
index 29e7ef55..0c37f796 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/train.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/train.py
@@ -19,7 +19,7 @@ from yacs.config import CfgNode
from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTrainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments


def main_sp(config, args):
diff --git a/paddlespeech/s2t/frontend/augmentor/augmentation.py b/paddlespeech/s2t/frontend/augmentor/augmentation.py
index 4c5ca4fe..744ea56d 100644
--- a/paddlespeech/s2t/frontend/augmentor/augmentation.py
+++ b/paddlespeech/s2t/frontend/augmentor/augmentation.py
@@ -45,7 +45,7 @@ class AugmentationPipeline():
samples to make the model invariant to certain types of perturbations
in the real world, improving model's generalization ability.

- The pipeline is built according the the augmentation configuration in json
+ The pipeline is built according to the augmentation configuration in json
string, e.g.

.. code-block::
diff --git a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py
index 982c6b8f..7623d0b8 100644
--- a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py
+++ b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py
@@ -48,13 +48,16 @@ class TextFeaturizer():
self.unit_type = unit_type
self.unk = UNK
self.maskctc = maskctc
+ self.vocab_path_or_list = vocab

- if vocab:
+ if self.vocab_path_or_list:
self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
vocab, maskctc)
self.vocab_size = len(self.vocab_list)
else:
- logger.warning("TextFeaturizer: not have vocab file or vocab list.")
+ logger.warning(
+ "TextFeaturizer: no vocab file or vocab list given; only the tokenizer is usable, tokens can not be converted to ids"
+ )

if unit_type == 'spm':
spm_model = spm_model_prefix + '.model'
@@ -62,6 +65,7 @@ class TextFeaturizer():
self.sp.Load(spm_model)

def tokenize(self, text, replace_space=True):
+ """Split text into text tokens."""
if self.unit_type == 'char':
tokens = self.char_tokenize(text, replace_space)
elif self.unit_type == 'word':
@@ -71,6 +75,7 @@ class TextFeaturizer():
return tokens

def detokenize(self, tokens):
+ """Convert text tokens back to text."""
if self.unit_type == 'char':
text = self.char_detokenize(tokens)
elif self.unit_type == 'word':
@@ -88,6 +93,7 @@ class TextFeaturizer():
Returns:
List[int]: List of token indices.
"""
+ assert self.vocab_path_or_list, "featurize needs a vocab path or vocab list"
tokens = self.tokenize(text)
ids = []
for token in tokens:
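To make the intent of the new guard concrete: without a vocab, `TextFeaturizer` degrades to a pure tokenizer. A small sketch, with the behavior inferred from the code above:

```python
# Sketch of the vocab-less mode that the new warning and asserts describe.
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer

featurizer = TextFeaturizer('char', None)  # no vocab file or vocab list
tokens = featurizer.tokenize("hi there")   # works: spaces map to the SPACE token
text = featurizer.detokenize(tokens)       # works: pure string round-trip
# featurizer.featurize("hi there")         # would trip the assert: no vocab to map tokens to ids
```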
""" + assert self.vocab_path_or_list, "toidx need vocab path or vocab list" tokens = [] for idx in idxs: if idx == self.eos_id: @@ -127,10 +134,10 @@ class TextFeaturizer(): """ text = text.strip() if replace_space: - text_list = [SPACE if item == " " else item for item in list(text)] + tokens = [SPACE if item == " " else item for item in list(text)] else: - text_list = list(text) - return text_list + tokens = list(text) + return tokens def char_detokenize(self, tokens): """Character detokenizer. diff --git a/paddlespeech/s2t/io/speechbrain/sampler.py b/paddlespeech/s2t/io/speechbrain/sampler.py index ba13193e..09a884c2 100755 --- a/paddlespeech/s2t/io/speechbrain/sampler.py +++ b/paddlespeech/s2t/io/speechbrain/sampler.py @@ -283,7 +283,7 @@ class DynamicBatchSampler(Sampler): num_quantiles, ) # get quantiles using lognormal distribution quantiles = lognorm.ppf(latent_boundaries, 1) - # scale up to to max_batch_length + # scale up to max_batch_length bucket_boundaries = quantiles * max_batch_length / quantiles[-1] # compute resulting bucket length multipliers length_multipliers = [ diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 6494b530..f716fa3b 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -560,7 +560,7 @@ class U2BaseModel(ASRInterface, nn.Layer): [len(hyp[0]) for hyp in hyps], place=device, dtype=paddle.long) # (beam_size,) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) - hyps_lens = hyps_lens + 1 # Add at begining + hyps_lens = hyps_lens + 1 # Add at beginning logger.debug( f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}") @@ -709,7 +709,7 @@ class U2BaseModel(ASRInterface, nn.Layer): hypothesis from ctc prefix beam search and one encoder output Args: hyps (paddle.Tensor): hyps from ctc prefix beam search, already - pad sos at the begining, (B, T) + pad sos at the beginning, (B, T) hyps_lens (paddle.Tensor): length of each hyp in hyps, (B) encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D) Returns: diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index 31defbba..b4c8c255 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -455,7 +455,7 @@ class U2STBaseModel(nn.Layer): hypothesis from ctc prefix beam search and one encoder output Args: hyps (paddle.Tensor): hyps from ctc prefix beam search, already - pad sos at the begining, (B, T) + pad sos at the beginning, (B, T) hyps_lens (paddle.Tensor): length of each hyp in hyps, (B) encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D) Returns: diff --git a/paddlespeech/s2t/utils/utility.py b/paddlespeech/s2t/utils/utility.py index d7e7c6ca..5655ec3f 100644 --- a/paddlespeech/s2t/utils/utility.py +++ b/paddlespeech/s2t/utils/utility.py @@ -29,10 +29,7 @@ from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() -__all__ = [ - "all_version", "UpdateConfig", "seed_all", 'print_arguments', - 'add_arguments', "log_add" -] +__all__ = ["all_version", "UpdateConfig", "seed_all", "log_add"] def all_version(): @@ -60,51 +57,6 @@ def seed_all(seed: int=20210329): paddle.seed(seed) -def print_arguments(args, info=None): - """Print argparse's arguments. - - Usage: - - .. 
code-block:: python
-
- parser = argparse.ArgumentParser()
- parser.add_argument("name", default="Jonh", type=str, help="User name.")
- args = parser.parse_args()
- print_arguments(args)
-
- :param args: Input argparse.Namespace for printing.
- :type args: argparse.Namespace
- """
- filename = ""
- if info:
- filename = info["__file__"]
- filename = os.path.basename(filename)
- print(f"----------- {filename} Arguments -----------")
- for arg, value in sorted(vars(args).items()):
- print("%s: %s" % (arg, value))
- print("-----------------------------------------------------------")
-
-
-def add_arguments(argname, type, default, help, argparser, **kwargs):
- """Add argparse's argument.
-
- Usage:
-
- .. code-block:: python
-
- parser = argparse.ArgumentParser()
- add_argument("name", str, "Jonh", "User name.", parser)
- args = parser.parse_args()
- """
- type = distutils.util.strtobool if type == bool else type
- argparser.add_argument(
- "--" + argname,
- default=default,
- type=type,
- help=help + ' Default: %(default)s.',
- **kwargs)
-
-
def log_add(args: List[int]) -> float:
"""Stable log add
diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py
index 536ffe0a..a702f0aa 100644
--- a/paddlespeech/server/engine/asr/online/python/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py
@@ -609,7 +609,7 @@ class PaddleASRConnectionHanddler:
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.model.sos, self.model.eos,
self.model.ignore_id)
- hyps_lens = hyps_lens + 1 # Add <sos> at begining
+ hyps_lens = hyps_lens + 1 # Add <sos> at beginning

# ctc score in ln domain
# (beam_size, max_hyps_len, vocab_size)
diff --git a/paddlespeech/server/ws/asr_api.py b/paddlespeech/server/ws/asr_api.py
index ae1c8831..b3ad0b7c 100644
--- a/paddlespeech/server/ws/asr_api.py
+++ b/paddlespeech/server/ws/asr_api.py
@@ -67,7 +67,7 @@ async def websocket_endpoint(websocket: WebSocket):
# and we break the loop
if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"}
- # do something at begining here
+ # do something at beginning here
# create the instance to process the audio
#connection_handler = PaddleASRConnectionHanddler(asr_model)
connection_handler = asr_model.new_handler()
diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py
index 9ae791b4..fe5d977a 100644
--- a/paddlespeech/t2s/datasets/am_batch_fn.py
+++ b/paddlespeech/t2s/datasets/am_batch_fn.py
@@ -114,7 +114,7 @@ def erniesat_batch_fn(examples,
]
span_bdy = paddle.to_tensor(span_bdy)

- # dual_mask 的是混合中英时候同时 mask 语音和文本
+ # dual_mask 的是混合中英时候同时 mask 语音和文本
# ernie sat 在实现跨语言的时候都 mask 了
if text_masking:
masked_pos, text_masked_pos = phones_text_masking(
@@ -153,7 +153,7 @@ def erniesat_batch_fn(examples,
batch = {
"text": text,
"speech": speech,
- # need to generate
+ # need to generate
"masked_pos": masked_pos,
"speech_mask": speech_mask,
"text_mask": text_mask,
@@ -415,10 +415,13 @@ def fastspeech2_multi_spk_batch_fn(examples):


def diffsinger_single_spk_batch_fn(examples):
- # fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
+ # fields = ["text", "note", "note_dur", "is_slur", "text_lengths", \
+ # "speech", "speech_lengths", "durations", "pitch", "energy"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
note = [np.array(item["note"], dtype=np.int64) for item in examples]
-
note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples] + note_dur = [ + np.array(item["note_dur"], dtype=np.float32) for item in examples + ] is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples] speech = [np.array(item["speech"], dtype=np.float32) for item in examples] pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples] @@ -471,10 +474,13 @@ def diffsinger_single_spk_batch_fn(examples): def diffsinger_multi_spk_batch_fn(examples): - # fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"] + # fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", \ + # "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"] text = [np.array(item["text"], dtype=np.int64) for item in examples] note = [np.array(item["note"], dtype=np.int64) for item in examples] - note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples] + note_dur = [ + np.array(item["note_dur"], dtype=np.float32) for item in examples + ] is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples] speech = [np.array(item["speech"], dtype=np.float32) for item in examples] pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples] @@ -663,6 +669,211 @@ def vits_multi_spk_batch_fn(examples): return batch +def jets_single_spk_batch_fn(examples): + """ + Returns: + Dict[str, Any]: + - text (Tensor): Text index tensor (B, T_text). + - text_lengths (Tensor): Text length tensor (B,). + - feats (Tensor): Feature tensor (B, T_feats, aux_channels). + - feats_lengths (Tensor): Feature length tensor (B,). + - durations (Tensor): Feature tensor (B, T_text,). + - durations_lengths (Tensor): Durations length tensor (B,). + - pitch (Tensor): Feature tensor (B, pitch_length,). + - energy (Tensor): Feature tensor (B, energy_length,). + - speech (Tensor): Speech waveform tensor (B, T_wav). 
+ + """ + # fields = ["text", "text_lengths", "feats", "feats_lengths", "durations", "pitch", "energy", "speech"] + text = [np.array(item["text"], dtype=np.int64) for item in examples] + feats = [np.array(item["feats"], dtype=np.float32) for item in examples] + durations = [ + np.array(item["durations"], dtype=np.int64) for item in examples + ] + pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples] + energy = [np.array(item["energy"], dtype=np.float32) for item in examples] + speech = [np.array(item["wave"], dtype=np.float32) for item in examples] + + text_lengths = [ + np.array(item["text_lengths"], dtype=np.int64) for item in examples + ] + feats_lengths = [ + np.array(item["feats_lengths"], dtype=np.int64) for item in examples + ] + + text = batch_sequences(text) + feats = batch_sequences(feats) + durations = batch_sequences(durations) + pitch = batch_sequences(pitch) + energy = batch_sequences(energy) + speech = batch_sequences(speech) + + # convert each batch to paddle.Tensor + text = paddle.to_tensor(text) + feats = paddle.to_tensor(feats) + durations = paddle.to_tensor(durations) + pitch = paddle.to_tensor(pitch) + energy = paddle.to_tensor(energy) + text_lengths = paddle.to_tensor(text_lengths) + feats_lengths = paddle.to_tensor(feats_lengths) + + batch = { + "text": text, + "text_lengths": text_lengths, + "feats": feats, + "feats_lengths": feats_lengths, + "durations": durations, + "durations_lengths": text_lengths, + "pitch": pitch, + "energy": energy, + "speech": speech, + } + return batch + + +def jets_multi_spk_batch_fn(examples): + """ + Returns: + Dict[str, Any]: + - text (Tensor): Text index tensor (B, T_text). + - text_lengths (Tensor): Text length tensor (B,). + - feats (Tensor): Feature tensor (B, T_feats, aux_channels). + - feats_lengths (Tensor): Feature length tensor (B,). + - durations (Tensor): Feature tensor (B, T_text,). + - durations_lengths (Tensor): Durations length tensor (B,). + - pitch (Tensor): Feature tensor (B, pitch_length,). + - energy (Tensor): Feature tensor (B, energy_length,). + - speech (Tensor): Speech waveform tensor (B, T_wav). + - spk_id (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). + - spk_emb (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). 
+ """ + # fields = ["text", "text_lengths", "feats", "feats_lengths", "durations", "pitch", "energy", "speech", "spk_id"/"spk_emb"] + text = [np.array(item["text"], dtype=np.int64) for item in examples] + feats = [np.array(item["feats"], dtype=np.float32) for item in examples] + durations = [ + np.array(item["durations"], dtype=np.int64) for item in examples + ] + pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples] + energy = [np.array(item["energy"], dtype=np.float32) for item in examples] + speech = [np.array(item["wave"], dtype=np.float32) for item in examples] + text_lengths = [ + np.array(item["text_lengths"], dtype=np.int64) for item in examples + ] + feats_lengths = [ + np.array(item["feats_lengths"], dtype=np.int64) for item in examples + ] + + text = batch_sequences(text) + feats = batch_sequences(feats) + durations = batch_sequences(durations) + pitch = batch_sequences(pitch) + energy = batch_sequences(energy) + speech = batch_sequences(speech) + + # convert each batch to paddle.Tensor + text = paddle.to_tensor(text) + feats = paddle.to_tensor(feats) + durations = paddle.to_tensor(durations) + pitch = paddle.to_tensor(pitch) + energy = paddle.to_tensor(energy) + text_lengths = paddle.to_tensor(text_lengths) + feats_lengths = paddle.to_tensor(feats_lengths) + + batch = { + "text": text, + "text_lengths": text_lengths, + "feats": feats, + "feats_lengths": feats_lengths, + "durations": durations, + "durations_lengths": text_lengths, + "pitch": pitch, + "energy": energy, + "speech": speech, + } + # spk_emb has a higher priority than spk_id + if "spk_emb" in examples[0]: + spk_emb = [ + np.array(item["spk_emb"], dtype=np.float32) for item in examples + ] + spk_emb = batch_sequences(spk_emb) + spk_emb = paddle.to_tensor(spk_emb) + batch["spk_emb"] = spk_emb + elif "spk_id" in examples[0]: + spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples] + spk_id = paddle.to_tensor(spk_id) + batch["spk_id"] = spk_id + return batch + + +# 因为要传参数,所以需要额外构建 +def build_starganv2_vc_collate_fn(latent_dim: int=16, max_mel_length: int=192): + + return StarGANv2VCCollateFn( + latent_dim=latent_dim, max_mel_length=max_mel_length) + + +class StarGANv2VCCollateFn: + """Functor class of common_collate_fn()""" + + def __init__(self, latent_dim: int=16, max_mel_length: int=192): + self.latent_dim = latent_dim + self.max_mel_length = max_mel_length + + def random_clip(self, mel: np.array): + # [T, 80] + mel_length = mel.shape[0] + if mel_length > self.max_mel_length: + random_start = np.random.randint(0, + mel_length - self.max_mel_length) + + mel = mel[random_start:random_start + self.max_mel_length, :] + return mel + + def __call__(self, exmaples): + return self.starganv2_vc_batch_fn(exmaples) + + def starganv2_vc_batch_fn(self, examples): + batch_size = len(examples) + + label = [np.array(item["label"], dtype=np.int64) for item in examples] + ref_label = [ + np.array(item["ref_label"], dtype=np.int64) for item in examples + ] + + # 需要对 mel 进行裁剪 + mel = [self.random_clip(item["mel"]) for item in examples] + ref_mel = [self.random_clip(item["ref_mel"]) for item in examples] + ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples] + mel = batch_sequences(mel) + ref_mel = batch_sequences(ref_mel) + ref_mel_2 = batch_sequences(ref_mel_2) + + # convert each batch to paddle.Tensor + # (B,) + label = paddle.to_tensor(label) + ref_label = paddle.to_tensor(ref_label) + # [B, T, 80] -> [B, 1, 80, T] + mel = paddle.to_tensor(mel).transpose([0, 2, 1]).unsqueeze(1) + 
ref_mel = paddle.to_tensor(ref_mel).transpose([0, 2, 1]).unsqueeze(1) + ref_mel_2 = paddle.to_tensor(ref_mel_2).transpose( + [0, 2, 1]).unsqueeze(1) + + z_trg = paddle.randn([batch_size, self.latent_dim]) + z_trg2 = paddle.randn([batch_size, self.latent_dim]) + + batch = { + "x_real": mel, + "y_org": label, + "x_ref": ref_mel, + "x_ref2": ref_mel_2, + "y_trg": ref_label, + "z_trg": z_trg, + "z_trg2": z_trg2 + } + + return batch + + # for PaddleSlim def fastspeech2_single_spk_batch_fn_static(examples): text = [np.array(item["text"], dtype=np.int64) for item in examples] diff --git a/paddlespeech/t2s/datasets/data_table.py b/paddlespeech/t2s/datasets/data_table.py index c9815af2..4ac67546 100644 --- a/paddlespeech/t2s/datasets/data_table.py +++ b/paddlespeech/t2s/datasets/data_table.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random from multiprocessing import Manager from typing import Any from typing import Callable from typing import Dict from typing import List +import numpy as np from paddle.io import Dataset @@ -131,3 +133,54 @@ class DataTable(Dataset): The length of the dataset """ return len(self.data) + + +class StarGANv2VCDataTable(DataTable): + def __init__(self, data: List[Dict[str, Any]]): + super().__init__(data) + raw_data = data + spk_id_set = list(set([item['spk_id'] for item in raw_data])) + data_list_per_class = {} + for spk_id in spk_id_set: + data_list_per_class[spk_id] = [] + for item in raw_data: + for spk_id in spk_id_set: + if item['spk_id'] == spk_id: + data_list_per_class[spk_id].append(item) + self.data_list_per_class = data_list_per_class + + def __getitem__(self, idx: int) -> Dict[str, Any]: + """Get an example given an index. 
+ Args:
+ idx (int): Index of the example to get
+
+ Returns:
+ Dict[str, Any]: A converted example
+ """
+ if self.use_cache and self.caches[idx] is not None:
+ return self.caches[idx]
+
+ data = self._get_metadata(idx)
+
+ # clipping is done in batch_fn
+ # returns a dict like:
+ """
+ {'utt_id': 'p225_111', 'spk_id': '1', 'speech': 'path of *.npy'}
+ """
+ ref_data = random.choice(self.data)
+ ref_label = ref_data['spk_id']
+ ref_data_2 = random.choice(self.data_list_per_class[ref_label])
+ # mel_tensor, label, ref_mel_tensor, ref2_mel_tensor, ref_label
+ new_example = {
+ 'utt_id': data['utt_id'],
+ 'mel': np.load(data['speech']),
+ 'label': int(data['spk_id']),
+ 'ref_mel': np.load(ref_data['speech']),
+ 'ref_mel_2': np.load(ref_data_2['speech']),
+ 'ref_label': int(ref_label)
+ }
+
+ if self.use_cache:
+ self.caches[idx] = new_example
+
+ return new_example
diff --git a/paddlespeech/t2s/exps/diffsinger/preprocess.py b/paddlespeech/t2s/exps/diffsinger/preprocess.py
index be526eff..a60ad44d 100644
--- a/paddlespeech/t2s/exps/diffsinger/preprocess.py
+++ b/paddlespeech/t2s/exps/diffsinger/preprocess.py
@@ -354,6 +354,7 @@ def main():
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
+ nprocs=args.num_cpu,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=args.write_metadata_method)
diff --git a/paddlespeech/t2s/exps/ernie_sat/preprocess.py b/paddlespeech/t2s/exps/ernie_sat/preprocess.py
index 486ed13a..04bbc074 100644
--- a/paddlespeech/t2s/exps/ernie_sat/preprocess.py
+++ b/paddlespeech/t2s/exps/ernie_sat/preprocess.py
@@ -324,6 +324,7 @@ def main():
sentences=sentences,
output_dir=dev_dump_dir,
mel_extractor=mel_extractor,
+ nprocs=args.num_cpu,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir)
if test_wav_files:
diff --git a/paddlespeech/t2s/exps/fastspeech2/preprocess.py b/paddlespeech/t2s/exps/fastspeech2/preprocess.py
index 521b9a88..a2353242 100644
--- a/paddlespeech/t2s/exps/fastspeech2/preprocess.py
+++ b/paddlespeech/t2s/exps/fastspeech2/preprocess.py
@@ -382,6 +382,7 @@ def main():
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
+ nprocs=args.num_cpu,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=args.write_metadata_method)
diff --git a/paddlespeech/t2s/exps/jets/__init__.py b/paddlespeech/t2s/exps/jets/__init__.py
new file mode 100644
index 00000000..97043fd7
--- /dev/null
+++ b/paddlespeech/t2s/exps/jets/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
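Stepping back to the `StarGANv2VCDataTable` added above: every example is paired with a random reference utterance plus a second utterance from that same reference speaker. A small runnable sketch, assuming the base `DataTable` keeps its default constructor arguments; the two metadata records and `.npy` dumps below are synthetic stand-ins:

```python
# Sketch of the reference-pairing behavior of StarGANv2VCDataTable.
import numpy as np
from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable

np.save("p225_001.npy", np.zeros((120, 80), dtype=np.float32))  # fake mel dumps
np.save("p226_001.npy", np.zeros((90, 80), dtype=np.float32))
table = StarGANv2VCDataTable(data=[
    {"utt_id": "p225_001", "spk_id": "0", "speech": "p225_001.npy"},
    {"utt_id": "p226_001", "spk_id": "1", "speech": "p226_001.npy"},
])
example = table[0]
# keys: utt_id, mel, label, ref_mel, ref_mel_2, ref_label;
# ref_mel_2 is always drawn from the same speaker as ref_mel.
```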
diff --git a/paddlespeech/t2s/exps/jets/inference.py b/paddlespeech/t2s/exps/jets/inference.py
new file mode 100644
index 00000000..4f6882ed
--- /dev/null
+++ b/paddlespeech/t2s/exps/jets/inference.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+from pathlib import Path
+
+import paddle
+import soundfile as sf
+from timer import timer
+
+from paddlespeech.t2s.exps.syn_utils import get_am_output
+from paddlespeech.t2s.exps.syn_utils import get_frontend
+from paddlespeech.t2s.exps.syn_utils import get_predictor
+from paddlespeech.t2s.exps.syn_utils import get_sentences
+from paddlespeech.t2s.utils import str2bool
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Paddle Inference with acoustic model & vocoder.")
+ # acoustic model
+ parser.add_argument(
+ '--am',
+ type=str,
+ default='jets_csmsc',
+ choices=['jets_csmsc', 'jets_aishell3'],
+ help='Choose acoustic model type of tts task.')
+ parser.add_argument(
+ "--phones_dict", type=str, default=None, help="phone vocabulary file.")
+ parser.add_argument(
+ "--speaker_dict", type=str, default=None, help="speaker id map file.")
+ parser.add_argument(
+ '--spk_id',
+ type=int,
+ default=0,
+ help='spk id for multi speaker acoustic model')
+ # other
+ parser.add_argument(
+ '--lang',
+ type=str,
+ default='zh',
+ help='Choose model language. zh or en or mix')
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="text to synthesize, a 'utt_id sentence' pair per line")
+ parser.add_argument(
+ "--add-blank",
+ type=str2bool,
+ default=True,
+ help="whether to add blank between phones")
+ parser.add_argument(
+ "--inference_dir", type=str, help="dir to save inference models")
+ parser.add_argument("--output_dir", type=str, help="output dir")
+ # inference
+ parser.add_argument(
+ "--use_trt",
+ type=str2bool,
+ default=False,
+ help="whether to use TensorRT or not in GPU", )
+ parser.add_argument(
+ "--use_mkldnn",
+ type=str2bool,
+ default=False,
+ help="whether to use MKLDNN or not in CPU.", )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ default='fp32',
+ choices=['fp32', 'fp16', 'bf16', 'int8'],
+ help="mode of running")
+ parser.add_argument(
+ "--device",
+ default="gpu",
+ choices=["gpu", "cpu"],
+ help="Device selected for inference.", )
+ parser.add_argument('--cpu_threads', type=int, default=1)
+
+ args, _ = parser.parse_known_args()
+ return args
+
+
+# only inference for models trained with csmsc now
+def main():
+ args = parse_args()
+
+ paddle.set_device(args.device)
+
+ # frontend
+ frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
+
+ # am_predictor
+ am_predictor = get_predictor(
+ model_dir=args.inference_dir,
+ model_file=args.am + ".pdmodel",
+ params_file=args.am + ".pdiparams",
+ device=args.device,
+ use_trt=args.use_trt,
+ use_mkldnn=args.use_mkldnn,
+ cpu_threads=args.cpu_threads,
+ precision=args.precision)
+ # model: {model_name}_{dataset}
+ am_dataset = args.am[args.am.rindex('_') + 1:]
+
+ output_dir = Path(args.output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ sentences = get_sentences(text_file=args.text, lang=args.lang)
+
+ merge_sentences = True
+ add_blank = args.add_blank
+ # jets's fs is 22050
+ fs = 22050
+ # warmup
+ for utt_id, sentence in sentences[:3]:
+ with timer() as t:
+ wav = get_am_output(
+ input=sentence,
+ am_predictor=am_predictor,
+ am=args.am,
+ frontend=frontend,
+ lang=args.lang,
+ merge_sentences=merge_sentences,
+ speaker_dict=args.speaker_dict,
+ spk_id=args.spk_id, )
+ speed = wav.size / t.elapse
+ rtf = fs / speed
+ print(
+ f"{utt_id}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ )
+
+ print("warm up done!")
+
+ N = 0
+ T = 0
+ for utt_id, sentence in sentences:
+ with timer() as t:
+ wav = get_am_output(
+ input=sentence,
+ am_predictor=am_predictor,
+ am=args.am,
+ frontend=frontend,
+ lang=args.lang,
+ merge_sentences=merge_sentences,
+ speaker_dict=args.speaker_dict,
+ spk_id=args.spk_id, )
+
+ N += wav.size
+ T += t.elapse
+ speed = wav.size / t.elapse
+ rtf = fs / speed
+ sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=fs)
+ print(
+ f"{utt_id}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}."
+ )
+
+ print(f"{utt_id} done!")
+ print(f"generation speed: {N / T}Hz, RTF: {fs / (N / T) }")
+
+
+if __name__ == "__main__":
+ main()
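The `speed`/`rtf` bookkeeping above is worth spelling out once. A worked example with made-up numbers:

```python
# Worked example of the throughput/RTF arithmetic used in the loops above.
fs = 22050            # JETS output sample rate (Hz)
elapse = 0.5          # hypothetical wall-clock seconds for one sentence
num_samples = 44100   # samples produced, i.e. 2.0 s of audio
speed = num_samples / elapse  # 88200.0 samples generated per second
rtf = fs / speed              # 0.25: synthesis runs 4x faster than real time
print(speed, rtf)
```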
" + "you need to specify either *-scp or rootdir.") + + parser.add_argument( + "--dumpdir", + type=str, + required=True, + help="directory to dump normalized feature files.") + parser.add_argument( + "--feats-stats", type=str, required=True, help="feats statistics file.") + parser.add_argument( + "--pitch-stats", type=str, required=True, help="pitch statistics file.") + parser.add_argument( + "--energy-stats", + type=str, + required=True, + help="energy statistics file.") + parser.add_argument( + "--phones-dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--speaker-dict", type=str, default=None, help="speaker id map file.") + + args = parser.parse_args() + + dumpdir = Path(args.dumpdir).expanduser() + # use absolute path + dumpdir = dumpdir.resolve() + dumpdir.mkdir(parents=True, exist_ok=True) + + # get dataset + with jsonlines.open(args.metadata, 'r') as reader: + metadata = list(reader) + dataset = DataTable( + metadata, + converters={ + "feats": np.load, + "pitch": np.load, + "energy": np.load, + "wave": str, + }) + logging.info(f"The number of files = {len(dataset)}.") + + # restore scaler + feats_scaler = StandardScaler() + feats_scaler.mean_ = np.load(args.feats_stats)[0] + feats_scaler.scale_ = np.load(args.feats_stats)[1] + feats_scaler.n_features_in_ = feats_scaler.mean_.shape[0] + + pitch_scaler = StandardScaler() + pitch_scaler.mean_ = np.load(args.pitch_stats)[0] + pitch_scaler.scale_ = np.load(args.pitch_stats)[1] + pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0] + + energy_scaler = StandardScaler() + energy_scaler.mean_ = np.load(args.energy_stats)[0] + energy_scaler.scale_ = np.load(args.energy_stats)[1] + energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0] + + vocab_phones = {} + with open(args.phones_dict, 'rt') as f: + phn_id = [line.strip().split() for line in f.readlines()] + for phn, id in phn_id: + vocab_phones[phn] = int(id) + + vocab_speaker = {} + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + for spk, id in spk_id: + vocab_speaker[spk] = int(id) + + # process each file + output_metadata = [] + + for item in tqdm(dataset): + utt_id = item['utt_id'] + feats = item['feats'] + pitch = item['pitch'] + energy = item['energy'] + wave_path = item['wave'] + # normalize + feats = feats_scaler.transform(feats) + feats_dir = dumpdir / "data_feats" + feats_dir.mkdir(parents=True, exist_ok=True) + feats_path = feats_dir / f"{utt_id}_feats.npy" + np.save(feats_path, feats.astype(np.float32), allow_pickle=False) + + pitch = pitch_scaler.transform(pitch) + pitch_dir = dumpdir / "data_pitch" + pitch_dir.mkdir(parents=True, exist_ok=True) + pitch_path = pitch_dir / f"{utt_id}_pitch.npy" + np.save(pitch_path, pitch.astype(np.float32), allow_pickle=False) + + energy = energy_scaler.transform(energy) + energy_dir = dumpdir / "data_energy" + energy_dir.mkdir(parents=True, exist_ok=True) + energy_path = energy_dir / f"{utt_id}_energy.npy" + np.save(energy_path, energy.astype(np.float32), allow_pickle=False) + + phone_ids = [vocab_phones[p] for p in item['phones']] + spk_id = vocab_speaker[item["speaker"]] + record = { + "utt_id": item['utt_id'], + "spk_id": spk_id, + "text": phone_ids, + "text_lengths": item['text_lengths'], + "feats_lengths": item['feats_lengths'], + "durations": item['durations'], + "feats": str(feats_path), + "pitch": str(pitch_path), + "energy": str(energy_path), + "wave": str(wave_path), + } + # add spk_emb for voice cloning + if "spk_emb" in item: + 
record["spk_emb"] = str(item["spk_emb"]) + + output_metadata.append(record) + output_metadata.sort(key=itemgetter('utt_id')) + output_metadata_path = Path(args.dumpdir) / "metadata.jsonl" + with jsonlines.open(output_metadata_path, 'w') as writer: + for item in output_metadata: + writer.write(item) + logging.info(f"metadata dumped into {output_metadata_path}") + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/jets/preprocess.py b/paddlespeech/t2s/exps/jets/preprocess.py new file mode 100644 index 00000000..468941ea --- /dev/null +++ b/paddlespeech/t2s/exps/jets/preprocess.py @@ -0,0 +1,451 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import os +from concurrent.futures import ThreadPoolExecutor +from operator import itemgetter +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List + +import jsonlines +import librosa +import numpy as np +import tqdm +import yaml +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.get_feats import Energy +from paddlespeech.t2s.datasets.get_feats import LogMelFBank +from paddlespeech.t2s.datasets.get_feats import Pitch +from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length +from paddlespeech.t2s.datasets.preprocess_utils import get_input_token +from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur +from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map +from paddlespeech.t2s.datasets.preprocess_utils import merge_silence +from paddlespeech.t2s.utils import str2bool + + +def process_sentence(config: Dict[str, Any], + fp: Path, + sentences: Dict, + output_dir: Path, + mel_extractor=None, + pitch_extractor=None, + energy_extractor=None, + cut_sil: bool=True, + spk_emb_dir: Path=None, + token_average: bool=True): + utt_id = fp.stem + # for vctk + if utt_id.endswith("_mic2"): + utt_id = utt_id[:-5] + record = None + if utt_id in sentences: + # reading, resampling may occur + wav, _ = librosa.load( + str(fp), sr=config.fs, + mono=False) if "canton" in str(fp) else librosa.load( + str(fp), sr=config.fs) + if len(wav.shape) == 2 and "canton" in str(fp): + # Remind that Cantonese datasets should be placed in ~/datasets/canton_all. Otherwise, it may cause problem. + wav = wav[0] + wav = np.ascontiguousarray(wav) + elif len(wav.shape) != 1: + return record + max_value = np.abs(wav).max() + if max_value > 1.0: + wav = wav / max_value + assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio." + assert np.abs(wav).max( + ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM." 
diff --git a/paddlespeech/t2s/exps/jets/preprocess.py b/paddlespeech/t2s/exps/jets/preprocess.py
new file mode 100644
index 00000000..468941ea
--- /dev/null
+++ b/paddlespeech/t2s/exps/jets/preprocess.py
@@ -0,0 +1,451 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from concurrent.futures import ThreadPoolExecutor
+from operator import itemgetter
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+
+import jsonlines
+import librosa
+import numpy as np
+import tqdm
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.get_feats import Energy
+from paddlespeech.t2s.datasets.get_feats import LogMelFBank
+from paddlespeech.t2s.datasets.get_feats import Pitch
+from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
+from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
+from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
+from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
+from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
+from paddlespeech.t2s.utils import str2bool
+
+
+def process_sentence(config: Dict[str, Any],
+ fp: Path,
+ sentences: Dict,
+ output_dir: Path,
+ mel_extractor=None,
+ pitch_extractor=None,
+ energy_extractor=None,
+ cut_sil: bool=True,
+ spk_emb_dir: Path=None,
+ token_average: bool=True):
+ utt_id = fp.stem
+ # for vctk
+ if utt_id.endswith("_mic2"):
+ utt_id = utt_id[:-5]
+ record = None
+ if utt_id in sentences:
+ # reading, resampling may occur
+ wav, _ = librosa.load(
+ str(fp), sr=config.fs,
+ mono=False) if "canton" in str(fp) else librosa.load(
+ str(fp), sr=config.fs)
+ if len(wav.shape) == 2 and "canton" in str(fp):
+ # Note that Cantonese datasets should be placed in ~/datasets/canton_all; otherwise, it may cause problems.
+ wav = wav[0]
+ wav = np.ascontiguousarray(wav)
+ elif len(wav.shape) != 1:
+ return record
+ max_value = np.abs(wav).max()
+ if max_value > 1.0:
+ wav = wav / max_value
+ assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
+ assert np.abs(wav).max(
+ ) <= 1.0, f"{utt_id} seems not to be 16 bit PCM."
+ phones = sentences[utt_id][0]
+ durations = sentences[utt_id][1]
+ speaker = sentences[utt_id][2]
+ d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
+ # slightly less precise than using *.TextGrid directly
+ times = librosa.frames_to_time(
+ d_cumsum, sr=config.fs, hop_length=config.n_shift)
+ if cut_sil:
+ start = 0
+ end = d_cumsum[-1]
+ if phones[0] == "sil" and len(durations) > 1:
+ start = times[1]
+ durations = durations[1:]
+ phones = phones[1:]
+ if phones[-1] == 'sil' and len(durations) > 1:
+ end = times[-2]
+ durations = durations[:-1]
+ phones = phones[:-1]
+ sentences[utt_id][0] = phones
+ sentences[utt_id][1] = durations
+ start, end = librosa.time_to_samples([start, end], sr=config.fs)
+ wav = wav[start:end]
+ # extract mel feats
+ logmel = mel_extractor.get_log_mel_fbank(wav)
+ # change duration according to mel_length
+ compare_duration_and_mel_length(sentences, utt_id, logmel)
+ # utt_id may be popped in compare_duration_and_mel_length
+ if utt_id not in sentences:
+ return None
+ phones = sentences[utt_id][0]
+ durations = sentences[utt_id][1]
+ num_frames = logmel.shape[0]
+ assert sum(durations) == num_frames
+ mel_dir = output_dir / "data_feats"
+ mel_dir.mkdir(parents=True, exist_ok=True)
+ mel_path = mel_dir / (utt_id + "_feats.npy")
+ np.save(mel_path, logmel)
+
+ if wav.size < num_frames * config.n_shift:
+ wav = np.pad(
+ wav, (0, num_frames * config.n_shift - wav.size),
+ mode="reflect")
+ else:
+ wav = wav[:num_frames * config.n_shift]
+ wave_dir = output_dir / "data_wave"
+ wave_dir.mkdir(parents=True, exist_ok=True)
+ wav_path = wave_dir / (utt_id + "_wave.npy")
+ # (num_samples, )
+ np.save(wav_path, wav)
+ # extract pitch and energy
+ if token_average:
+ f0 = pitch_extractor.get_pitch(
+ wav,
+ duration=np.array(durations),
+ use_token_averaged_f0=token_average)
+ if (f0 == 0).all():
+ return None
+ assert f0.shape[0] == len(durations)
+ else:
+ f0 = pitch_extractor.get_pitch(
+ wav, use_token_averaged_f0=token_average)
+ if (f0 == 0).all():
+ return None
+ f0 = f0[:num_frames]
+ assert f0.shape[0] == num_frames
+ f0_dir = output_dir / "data_pitch"
+ f0_dir.mkdir(parents=True, exist_ok=True)
+ f0_path = f0_dir / (utt_id + "_pitch.npy")
+ np.save(f0_path, f0)
+ if token_average:
+ energy = energy_extractor.get_energy(
+ wav,
+ duration=np.array(durations),
+ use_token_averaged_energy=token_average)
+ assert energy.shape[0] == len(durations)
+ else:
+ energy = energy_extractor.get_energy(
+ wav, use_token_averaged_energy=token_average)
+ energy = energy[:num_frames]
+ assert energy.shape[0] == num_frames
+
+ energy_dir = output_dir / "data_energy"
+ energy_dir.mkdir(parents=True, exist_ok=True)
+ energy_path = energy_dir / (utt_id + "_energy.npy")
+ np.save(energy_path, energy)
+ record = {
+ "utt_id": utt_id,
+ "phones": phones,
+ "text_lengths": len(phones),
+ "feats_lengths": num_frames,
+ "durations": durations,
+ "feats": str(mel_path),
+ "pitch": str(f0_path),
+ "energy": str(energy_path),
+ "wave": str(wav_path),
+ "speaker": speaker
+ }
+ if spk_emb_dir:
+ if speaker in os.listdir(spk_emb_dir):
+ embed_name = utt_id + ".npy"
+ embed_path = spk_emb_dir / speaker / embed_name
+ if embed_path.is_file():
+ record["spk_emb"] = str(embed_path)
+ else:
+ return None
+ return record
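For readers unfamiliar with `use_token_averaged_f0`: when `token_average` is on, the frame-level pitch track is collapsed to one value per phone, guided by the durations. A rough numpy illustration; the real logic lives in the `Pitch` extractor, and averaging only voiced frames is an assumption of this sketch:

```python
# Rough sketch of token averaging over a frame-level f0 track; this is an
# illustration, not the extractor's exact code.
import numpy as np

f0 = np.array([100., 102., 0., 98., 200., 202.])  # per-frame f0, 0 = unvoiced
durations = np.array([3, 1, 2])                   # frames per phone token
offsets = np.concatenate([[0], np.cumsum(durations)])
avg = np.array([
    f0[s:e][f0[s:e] > 0].mean() if (f0[s:e] > 0).any() else 0.0
    for s, e in zip(offsets[:-1], offsets[1:])
])
print(avg)  # [101.  98. 201.] -> one averaged f0 value per token
```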
+                      write_metadata_method: str='w',
+                      token_average: bool=True):
+    if nprocs == 1:
+        results = []
+        for fp in tqdm.tqdm(fps, total=len(fps)):
+            record = process_sentence(
+                config=config,
+                fp=fp,
+                sentences=sentences,
+                output_dir=output_dir,
+                mel_extractor=mel_extractor,
+                pitch_extractor=pitch_extractor,
+                energy_extractor=energy_extractor,
+                cut_sil=cut_sil,
+                spk_emb_dir=spk_emb_dir,
+                token_average=token_average)
+            if record:
+                results.append(record)
+    else:
+        with ThreadPoolExecutor(nprocs) as pool:
+            futures = []
+            with tqdm.tqdm(total=len(fps)) as progress:
+                for fp in fps:
+                    future = pool.submit(process_sentence, config, fp,
+                                         sentences, output_dir, mel_extractor,
+                                         pitch_extractor, energy_extractor,
+                                         cut_sil, spk_emb_dir, token_average)
+                    future.add_done_callback(lambda p: progress.update())
+                    futures.append(future)
+
+                results = []
+                for ft in futures:
+                    record = ft.result()
+                    if record:
+                        results.append(record)
+
+    results.sort(key=itemgetter("utt_id"))
+    with jsonlines.open(output_dir / "metadata.jsonl",
+                        write_metadata_method) as writer:
+        for item in results:
+            writer.write(item)
+    print("Done")
+
+
+def main():
+    # parse config and args
+    parser = argparse.ArgumentParser(
+        description="Preprocess audio and then extract features.")
+
+    parser.add_argument(
+        "--dataset",
+        default="baker",
+        type=str,
+        help="name of the dataset; should be one of {baker, aishell3, ljspeech, vctk, canton} now")
+
+    parser.add_argument(
+        "--rootdir", default=None, type=str, help="directory to dataset.")
+
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump feature files.")
+    parser.add_argument(
+        "--dur-file", default=None, type=str, help="path to durations.txt.")
+
+    parser.add_argument("--config", type=str, help="fastspeech2 config file.")
+
+    parser.add_argument(
+        "--num-cpu", type=int, default=1, help="number of processes.")
+
+    parser.add_argument(
+        "--cut-sil",
+        type=str2bool,
+        default=True,
+        help="whether to cut silence at the edges of the audio")
+
+    parser.add_argument(
+        "--spk_emb_dir",
+        default=None,
+        type=str,
+        help="directory to speaker embedding files.")
+
+    parser.add_argument(
+        "--write_metadata_method",
+        default="w",
+        type=str,
+        choices=["w", "a"],
+        help="How the metadata.jsonl file is written.")
+
+    parser.add_argument(
+        "--token_average",
+        type=str2bool,
+        default=False,
+        help="Average the energy and pitch according to durations")
+    args = parser.parse_args()
+
+    rootdir = Path(args.rootdir).expanduser()
+    dumpdir = Path(args.dumpdir).expanduser()
+    # use absolute path
+    dumpdir = dumpdir.resolve()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+    dur_file = Path(args.dur_file).expanduser()
+
+    if args.spk_emb_dir:
+        spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
+    else:
+        spk_emb_dir = None
+
+    assert rootdir.is_dir()
+    assert dur_file.is_file()
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    sentences, speaker_set = get_phn_dur(dur_file)
+
+    merge_silence(sentences)
+    phone_id_map_path = dumpdir / "phone_id_map.txt"
+    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
+    get_input_token(sentences, phone_id_map_path, args.dataset)
+    get_spk_id_map(speaker_set, speaker_id_map_path)
+
+    if args.dataset == "baker":
+        wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
+        # split data into 3 sections
+        num_train = 9800
+        num_dev = 100
+        train_wav_files = wav_files[:num_train]
+        dev_wav_files = wav_files[num_train:num_train + num_dev]
+        test_wav_files = wav_files[num_train + num_dev:]
+    elif args.dataset ==
"aishell3": + sub_num_dev = 5 + wav_dir = rootdir / "train" / "wav" + train_wav_files = [] + dev_wav_files = [] + test_wav_files = [] + for speaker in os.listdir(wav_dir): + wav_files = sorted(list((wav_dir / speaker).rglob("*.wav"))) + if len(wav_files) > 100: + train_wav_files += wav_files[:-sub_num_dev * 2] + dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev] + test_wav_files += wav_files[-sub_num_dev:] + else: + train_wav_files += wav_files + elif args.dataset == "canton": + sub_num_dev = 5 + wav_dir = rootdir / "WAV" + train_wav_files = [] + dev_wav_files = [] + test_wav_files = [] + for speaker in os.listdir(wav_dir): + wav_files = sorted(list((wav_dir / speaker).rglob("*.wav"))) + if len(wav_files) > 100: + train_wav_files += wav_files[:-sub_num_dev * 2] + dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev] + test_wav_files += wav_files[-sub_num_dev:] + else: + train_wav_files += wav_files + elif args.dataset == "ljspeech": + wav_files = sorted(list((rootdir / "wavs").rglob("*.wav"))) + # split data into 3 sections + num_train = 12900 + num_dev = 100 + train_wav_files = wav_files[:num_train] + dev_wav_files = wav_files[num_train:num_train + num_dev] + test_wav_files = wav_files[num_train + num_dev:] + elif args.dataset == "vctk": + sub_num_dev = 5 + wav_dir = rootdir / "wav48_silence_trimmed" + train_wav_files = [] + dev_wav_files = [] + test_wav_files = [] + for speaker in os.listdir(wav_dir): + wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac"))) + if len(wav_files) > 100: + train_wav_files += wav_files[:-sub_num_dev * 2] + dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev] + test_wav_files += wav_files[-sub_num_dev:] + else: + train_wav_files += wav_files + + else: + print("dataset should in {baker, aishell3, ljspeech, vctk} now!") + + train_dump_dir = dumpdir / "train" / "raw" + train_dump_dir.mkdir(parents=True, exist_ok=True) + dev_dump_dir = dumpdir / "dev" / "raw" + dev_dump_dir.mkdir(parents=True, exist_ok=True) + test_dump_dir = dumpdir / "test" / "raw" + test_dump_dir.mkdir(parents=True, exist_ok=True) + + # Extractor + mel_extractor = LogMelFBank( + sr=config.fs, + n_fft=config.n_fft, + hop_length=config.n_shift, + win_length=config.win_length, + window=config.window, + n_mels=config.n_mels, + fmin=config.fmin, + fmax=config.fmax) + pitch_extractor = Pitch( + sr=config.fs, + hop_length=config.n_shift, + f0min=config.f0min, + f0max=config.f0max) + energy_extractor = Energy( + n_fft=config.n_fft, + hop_length=config.n_shift, + win_length=config.win_length, + window=config.window) + + # process for the 3 sections + if train_wav_files: + process_sentences( + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir, + write_metadata_method=args.write_metadata_method, + token_average=args.token_average) + if dev_wav_files: + process_sentences( + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir, + write_metadata_method=args.write_metadata_method, + token_average=args.token_average) + if test_wav_files: + process_sentences( + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + 
mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir, + write_metadata_method=args.write_metadata_method, + token_average=args.token_average) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/jets/synthesize.py b/paddlespeech/t2s/exps/jets/synthesize.py new file mode 100644 index 00000000..ef26414d --- /dev/null +++ b/paddlespeech/t2s/exps/jets/synthesize.py @@ -0,0 +1,153 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from pathlib import Path + +import jsonlines +import numpy as np +import paddle +import soundfile as sf +import yaml +from timer import timer +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.models.jets import JETS +from paddlespeech.t2s.utils import str2bool + + +def evaluate(args): + + # construct dataset for evaluation + with jsonlines.open(args.test_metadata, 'r') as reader: + test_metadata = list(reader) + # Init body. + with open(args.config) as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + + fields = ["utt_id", "text"] + converters = {} + + spk_num = None + if args.speaker_dict is not None: + print("multiple speaker jets!") + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + fields += ["spk_id"] + elif args.voice_cloning: + print("Evaluating voice cloning!") + fields += ["spk_emb"] + else: + print("single speaker jets!") + print("spk_num:", spk_num) + + test_dataset = DataTable( + data=test_metadata, + fields=fields, + converters=converters, ) + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_fft // 2 + 1 + config["model"]["generator_params"]["spks"] = spk_num + + jets = JETS(idim=vocab_size, odim=odim, **config["model"]) + jets.set_state_dict(paddle.load(args.ckpt)["main_params"]) + jets.eval() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + N = 0 + T = 0 + + for datum in test_dataset: + utt_id = datum["utt_id"] + phone_ids = paddle.to_tensor(datum["text"]) + with timer() as t: + with paddle.no_grad(): + spk_emb = None + spk_id = None + # multi speaker + if args.voice_cloning and "spk_emb" in datum: + spk_emb = paddle.to_tensor(np.load(datum["spk_emb"])) + elif "spk_id" in datum: + spk_id = paddle.to_tensor(datum["spk_id"]) + out = jets.inference( + text=phone_ids, sids=spk_id, spembs=spk_emb) + wav = out["wav"] + wav = wav.numpy() + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = config.fs / speed + print( + f"{utt_id}, wave: {wav.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." 
+ ) + sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) + print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") + + +def parse_args(): + # parse args and config + parser = argparse.ArgumentParser(description="Synthesize with JETS") + # model + parser.add_argument( + '--config', type=str, default=None, help='Config of JETS.') + parser.add_argument( + '--ckpt', type=str, default=None, help='Checkpoint file of JETS.') + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--speaker_dict", type=str, default=None, help="speaker id map file.") + parser.add_argument( + "--voice-cloning", + type=str2bool, + default=False, + help="whether training voice cloning model.") + # other + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument("--test_metadata", type=str, help="test metadata.") + parser.add_argument("--output_dir", type=str, help="output dir.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/jets/synthesize_e2e.py b/paddlespeech/t2s/exps/jets/synthesize_e2e.py new file mode 100644 index 00000000..1c713c06 --- /dev/null +++ b/paddlespeech/t2s/exps/jets/synthesize_e2e.py @@ -0,0 +1,189 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from pathlib import Path + +import paddle +import soundfile as sf +import yaml +from timer import timer +from yacs.config import CfgNode + +from paddlespeech.t2s.exps.syn_utils import am_to_static +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.models.jets import JETS +from paddlespeech.t2s.models.jets import JETSInference +from paddlespeech.t2s.utils import str2bool + + +def evaluate(args): + # Init body. 
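
Both of the JETS synthesis scripts above share the same timing bookkeeping: speed is the number of samples generated per wall-clock second, and RTF = fs / speed, i.e. synthesis time divided by audio duration, so an RTF below 1.0 means faster than real time. A worked example with assumed, illustrative numbers (not measured results):

    fs = 22050          # assumed output sample rate
    wav_size = 66150    # 3 s of audio at 22.05 kHz
    elapse = 0.5        # assumed wall-clock synthesis time in seconds

    speed = wav_size / elapse   # 132300.0 samples generated per second
    rtf = fs / speed            # 22050 / 132300 = 0.1667, i.e. ~6x real time
    print(f"Hz: {speed}, RTF: {rtf:.3f}")
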
+ with open(args.config) as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + + sentences = get_sentences(text_file=args.text, lang=args.lang) + + # frontend + frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict) + # acoustic model + am_name = args.am[:args.am.rindex('_')] + am_dataset = args.am[args.am.rindex('_') + 1:] + + spk_num = None + if args.speaker_dict is not None: + print("multiple speaker jets!") + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + else: + print("single speaker jets!") + print("spk_num:", spk_num) + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_fft // 2 + 1 + config["model"]["generator_params"]["spks"] = spk_num + + jets = JETS(idim=vocab_size, odim=odim, **config["model"]) + jets.set_state_dict(paddle.load(args.ckpt)["main_params"]) + jets.eval() + + jets_inference = JETSInference(jets) + # whether dygraph to static + if args.inference_dir: + jets_inference = am_to_static( + am_inference=jets_inference, + am=args.am, + inference_dir=args.inference_dir, + speaker_dict=args.speaker_dict) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + merge_sentences = False + + N = 0 + T = 0 + for utt_id, sentence in sentences: + with timer() as t: + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + elif args.lang == 'en': + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + with paddle.no_grad(): + flags = 0 + for i in range(len(phone_ids)): + part_phone_ids = phone_ids[i] + spk_id = None + if am_dataset in {"aishell3", + "vctk"} and spk_num is not None: + spk_id = paddle.to_tensor(args.spk_id) + wav = jets_inference(part_phone_ids, spk_id) + else: + wav = jets_inference(part_phone_ids) + if flags == 0: + wav_all = wav + flags = 1 + else: + wav_all = paddle.concat([wav_all, wav]) + wav = wav_all.numpy() + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = config.fs / speed + print( + f"{utt_id}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) + print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") + + +def parse_args(): + # parse args and config + parser = argparse.ArgumentParser(description="Synthesize with JETS") + + # model + parser.add_argument( + '--config', type=str, default=None, help='Config of JETS.') + parser.add_argument( + '--ckpt', type=str, default=None, help='Checkpoint file of JETS.') + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--speaker_dict", type=str, default=None, help="speaker id map file.") + parser.add_argument( + '--spk_id', + type=int, + default=0, + help='spk id for multi speaker acoustic model') + # other + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. 
zh or en') + + parser.add_argument( + "--inference_dir", + type=str, + default=None, + help="dir to save inference models") + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line.") + parser.add_argument("--output_dir", type=str, help="output dir.") + + parser.add_argument( + '--am', + type=str, + default='jets_csmsc', + help='Choose acoustic model type of tts task.') + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/jets/train.py b/paddlespeech/t2s/exps/jets/train.py new file mode 100644 index 00000000..7eb4031a --- /dev/null +++ b/paddlespeech/t2s/exps/jets/train.py @@ -0,0 +1,305 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +import os +import shutil +from pathlib import Path + +import jsonlines +import numpy as np +import paddle +import yaml +from paddle import DataParallel +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.optimizer import AdamW +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.am_batch_fn import jets_multi_spk_batch_fn +from paddlespeech.t2s.datasets.am_batch_fn import jets_single_spk_batch_fn +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.datasets.sampler import ErnieSATSampler +from paddlespeech.t2s.models.jets import JETS +from paddlespeech.t2s.models.jets import JETSEvaluator +from paddlespeech.t2s.models.jets import JETSUpdater +from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss +from paddlespeech.t2s.modules.losses import FeatureMatchLoss +from paddlespeech.t2s.modules.losses import ForwardSumLoss +from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss +from paddlespeech.t2s.modules.losses import MelSpectrogramLoss +from paddlespeech.t2s.modules.losses import VarianceLoss +from paddlespeech.t2s.training.extensions.snapshot import Snapshot +from paddlespeech.t2s.training.extensions.visualizer import VisualDL +from paddlespeech.t2s.training.optimizer import scheduler_classes +from paddlespeech.t2s.training.seeding import seed_everything +from paddlespeech.t2s.training.trainer import Trainer +from paddlespeech.t2s.utils import str2bool + + +def train_sp(args, config): + # decides device type and whether to run in parallel + # setup running environment correctly + world_size = paddle.distributed.get_world_size() + if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: + paddle.set_device("cpu") + else: + paddle.set_device("gpu") + if world_size > 1: + paddle.distributed.init_parallel_env() + + # set the random seed, it is a must for 
multiprocess training + seed_everything(config.seed) + + print( + f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", + ) + + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + + fields = [ + "text", "text_lengths", "feats", "feats_lengths", "wave", "durations", + "pitch", "energy" + ] + + converters = { + "wave": np.load, + "feats": np.load, + "pitch": np.load, + "energy": np.load, + } + spk_num = None + if args.speaker_dict is not None: + print("multiple speaker jets!") + collate_fn = jets_multi_spk_batch_fn + with open(args.speaker_dict, 'rt', encoding='utf-8') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + fields += ["spk_id"] + elif args.voice_cloning: + print("Training voice cloning!") + collate_fn = jets_multi_spk_batch_fn + fields += ["spk_emb"] + converters["spk_emb"] = np.load + else: + print("single speaker jets!") + collate_fn = jets_single_spk_batch_fn + print("spk_num:", spk_num) + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=fields, + converters=converters, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=fields, + converters=converters, ) + + # collate function and dataloader + train_sampler = ErnieSATSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=True) + dev_sampler = ErnieSATSampler( + dev_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=False) + print("samplers done!") + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=collate_fn, + num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + batch_sampler=dev_sampler, + collate_fn=collate_fn, + num_workers=config.num_workers) + print("dataloaders done!") + + with open(args.phones_dict, 'rt', encoding='utf-8') as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_mels + config["model"]["generator_params"]["spks"] = spk_num + model = JETS(idim=vocab_size, odim=odim, **config["model"]) + gen_parameters = model.generator.parameters() + dis_parameters = model.discriminator.parameters() + if world_size > 1: + model = DataParallel(model) + gen_parameters = model._layers.generator.parameters() + dis_parameters = model._layers.discriminator.parameters() + + print("model done!") + + # loss + criterion_mel = MelSpectrogramLoss( + **config["mel_loss_params"], ) + criterion_feat_match = FeatureMatchLoss( + **config["feat_match_loss_params"], ) + criterion_gen_adv = GeneratorAdversarialLoss( + **config["generator_adv_loss_params"], ) + criterion_dis_adv = DiscriminatorAdversarialLoss( + **config["discriminator_adv_loss_params"], ) + criterion_var = VarianceLoss() + criterion_forwardsum = ForwardSumLoss() + + print("criterions done!") + + lr_schedule_g = scheduler_classes[config["generator_scheduler"]]( + **config["generator_scheduler_params"]) + optimizer_g = AdamW( + learning_rate=lr_schedule_g, + parameters=gen_parameters, + **config["generator_optimizer_params"]) + + lr_schedule_d = scheduler_classes[config["discriminator_scheduler"]]( + **config["discriminator_scheduler_params"]) + optimizer_d = AdamW( + learning_rate=lr_schedule_d, + parameters=dis_parameters, 
+ **config["discriminator_optimizer_params"]) + + print("optimizers done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + updater = JETSUpdater( + model=model, + optimizers={ + "generator": optimizer_g, + "discriminator": optimizer_d, + }, + criterions={ + "mel": criterion_mel, + "feat_match": criterion_feat_match, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + "var": criterion_var, + "forwardsum": criterion_forwardsum, + }, + schedulers={ + "generator": lr_schedule_g, + "discriminator": lr_schedule_d, + }, + dataloader=train_dataloader, + lambda_adv=config.lambda_adv, + lambda_mel=config.lambda_mel, + lambda_feat_match=config.lambda_feat_match, + lambda_var=config.lambda_var, + lambda_align=config.lambda_align, + generator_first=config.generator_first, + use_alignment_module=config.use_alignment_module, + output_dir=output_dir) + + evaluator = JETSEvaluator( + model=model, + criterions={ + "mel": criterion_mel, + "feat_match": criterion_feat_match, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + "var": criterion_var, + "forwardsum": criterion_forwardsum, + }, + dataloader=dev_dataloader, + lambda_adv=config.lambda_adv, + lambda_mel=config.lambda_mel, + lambda_feat_match=config.lambda_feat_match, + lambda_var=config.lambda_var, + lambda_align=config.lambda_align, + generator_first=config.generator_first, + use_alignment_module=config.use_alignment_module, + output_dir=output_dir) + + trainer = Trainer( + updater, + stop_trigger=(config.train_max_steps, "iteration"), + out=output_dir) + + if dist.get_rank() == 0: + trainer.extend( + evaluator, trigger=(config.eval_interval_steps, 'iteration')) + trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration')) + trainer.extend( + Snapshot(max_size=config.num_snapshots), + trigger=(config.save_interval_steps, 'iteration')) + + print("Trainer Done!") + trainer.run() + + +def main(): + # parse args and config and redirect to train_sp + + parser = argparse.ArgumentParser(description="Train a JETS model.") + parser.add_argument("--config", type=str, help="JETS config file") + parser.add_argument("--train-metadata", type=str, help="training data.") + parser.add_argument("--dev-metadata", type=str, help="dev data.") + parser.add_argument("--output-dir", type=str, help="output dir.") + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument( + "--phones-dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--speaker-dict", + type=str, + default=None, + help="speaker id map file for multiple speaker model.") + + parser.add_argument( + "--voice-cloning", + type=str2bool, + default=False, + help="whether training voice cloning model.") + + args = parser.parse_args() + + with open(args.config, 'rt') as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + print( + f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}" + ) + + # dispatch + if args.ngpu > 1: + dist.spawn(train_sp, (args, config), nprocs=args.ngpu) + else: + train_sp(args, config) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/speedyspeech/preprocess.py b/paddlespeech/t2s/exps/speedyspeech/preprocess.py index 
e4084c14..75a1b079 100644
--- a/paddlespeech/t2s/exps/speedyspeech/preprocess.py
+++ b/paddlespeech/t2s/exps/speedyspeech/preprocess.py
@@ -280,6 +280,7 @@ def main():
         sentences=sentences,
         output_dir=dev_dump_dir,
         mel_extractor=mel_extractor,
+        nprocs=args.num_cpu,
         cut_sil=args.cut_sil,
         use_relative_path=args.use_relative_path)
     if test_wav_files:
diff --git a/paddlespeech/t2s/exps/starganv2_vc/normalize.py b/paddlespeech/t2s/exps/starganv2_vc/normalize.py
new file mode 100644
index 00000000..c063c46f
--- /dev/null
+++ b/paddlespeech/t2s/exps/starganv2_vc/normalize.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Normalize feature files and dump them."""
+import argparse
+import logging
+from operator import itemgetter
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+import tqdm
+
+from paddlespeech.t2s.datasets.data_table import DataTable
+
+
+def main():
+    """Run preprocessing process."""
+    parser = argparse.ArgumentParser(
+        description="Normalize dumped raw features (See details in parallel_wavegan/bin/normalize.py)."
+    )
+    parser.add_argument(
+        "--metadata",
+        type=str,
+        required=True,
+        help="path to the metadata.jsonl file of the dumped raw features "
+        "to be normalized.")
+
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump normalized feature files.")
+
+    parser.add_argument(
+        "--speaker-dict", type=str, default=None, help="speaker id map file.")
+
+    args = parser.parse_args()
+
+    dumpdir = Path(args.dumpdir).expanduser()
+    # use absolute path
+    dumpdir = dumpdir.resolve()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+
+    # get dataset
+    with jsonlines.open(args.metadata, 'r') as reader:
+        metadata = list(reader)
+    dataset = DataTable(
+        metadata, converters={
+            "speech": np.load,
+        })
+    logging.info(f"The number of files = {len(dataset)}.")
+
+    vocab_speaker = {}
+    with open(args.speaker_dict, 'rt') as f:
+        spk_id = [line.strip().split() for line in f.readlines()]
+    for spk, id in spk_id:
+        vocab_speaker[spk] = int(id)
+
+    # process each file
+    output_metadata = []
+
+    for item in tqdm.tqdm(dataset):
+        utt_id = item['utt_id']
+        speech = item['speech']
+
+        # normalize
+        # the statistics are hard-coded here for now
+        mean, std = -4, 4
+        speech = (speech - mean) / std
+        speech_path = dumpdir / f"{utt_id}_speech.npy"
+        np.save(speech_path, speech.astype(np.float32), allow_pickle=False)
+
+        spk_id = vocab_speaker[item["speaker"]]
+        record = {
+            "utt_id": item['utt_id'],
+            "spk_id": spk_id,
+            "speech": str(speech_path),
+        }
+
+        output_metadata.append(record)
+    output_metadata.sort(key=itemgetter('utt_id'))
+    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
+    with jsonlines.open(output_metadata_path, 'w') as writer:
+        for item in output_metadata:
+            writer.write(item)
+    logging.info(f"metadata dumped into {output_metadata_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/t2s/exps/starganv2_vc/preprocess.py b/paddlespeech/t2s/exps/starganv2_vc/preprocess.py
new file mode 100644
index 00000000..053c3b32
--- /dev/null
+++ b/paddlespeech/t2s/exps/starganv2_vc/preprocess.py
@@ -0,0 +1,214 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from concurrent.futures import ThreadPoolExecutor
+from operator import itemgetter
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+
+import jsonlines
+import librosa
+import numpy as np
+import tqdm
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.get_feats import LogMelFBank
+from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
+
+speaker_set = set()
+
+
+def process_sentence(config: Dict[str, Any],
+                     fp: Path,
+                     output_dir: Path,
+                     mel_extractor=None):
+    utt_id = fp.stem
+    # for vctk
+    if utt_id.endswith("_mic2"):
+        utt_id = utt_id[:-5]
+    speaker = utt_id.split('_')[0]
+    speaker_set.add(speaker)
+    # the speaker id needs to be obtained separately here
+    record = None
+    # reading; resampling may occur
+    # bug in the reference implementation: audio is read at 24000 Hz, but mels are extracted as if at 16000 Hz
+    # see https://github.com/PaddlePaddle/PaddleSpeech/blob/c7d24ba42c377fe4c0765c6b1faa202a9aeb136f/paddlespeech/t2s/exps/starganv2_vc/vc.py#L165 for details
+    # this should later be changed to read at 24000 Hz and extract mels at 24000 Hz as well
+    wav, _ = librosa.load(str(fp), sr=24000)
+    max_value = np.abs(wav).max()
+    if max_value > 1.0:
+        wav = wav / max_value
+    assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
+    assert np.abs(
+        wav).max() <= 1.0, f"{utt_id} seems to be different from 16-bit PCM."
+    # extract mel feats
+    # note: base='e' here; this should later be switched to base='10', which all our other TTS models use
+    logmel = mel_extractor.get_log_mel_fbank(wav, base='e')
+    mel_path = output_dir / (utt_id + "_speech.npy")
+    np.save(mel_path, logmel)
+    record = {"utt_id": utt_id, "speech": str(mel_path), "speaker": speaker}
+    return record
+
+
+def process_sentences(
+        config,
+        fps: List[Path],
+        output_dir: Path,
+        mel_extractor=None,
+        nprocs: int=1, ):
+    if nprocs == 1:
+        results = []
+        for fp in tqdm.tqdm(fps, total=len(fps)):
+            record = process_sentence(
+                config=config,
+                fp=fp,
+                output_dir=output_dir,
+                mel_extractor=mel_extractor)
+            if record:
+                results.append(record)
+    else:
+        with ThreadPoolExecutor(nprocs) as pool:
+            futures = []
+            with tqdm.tqdm(total=len(fps)) as progress:
+                for fp in fps:
+                    future = pool.submit(process_sentence, config, fp,
+                                         output_dir, mel_extractor)
+                    future.add_done_callback(lambda p: progress.update())
+                    futures.append(future)
+
+                results = []
+                for ft in futures:
+                    record = ft.result()
+                    if record:
+                        results.append(record)
+
+    results.sort(key=itemgetter("utt_id"))
+    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+        for item in results:
+            writer.write(item)
+    print("Done")
+
+
+def main():
+    # parse config and args
+    parser = argparse.ArgumentParser(
+        description="Preprocess audio and then extract features.")
+
+    parser.add_argument(
+        "--dataset",
+        default="vctk",
+        type=str,
+        help="name of the dataset; should be one of {vctk} now")
+
+    parser.add_argument(
+        "--rootdir", default=None, type=str, help="directory to dataset.")
+
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump feature files.")
+
+    parser.add_argument("--config", type=str, help="StarGANv2VC config file.")
+
+    parser.add_argument(
+        "--num-cpu", type=int, default=1, help="number of processes.")
+
+    args = parser.parse_args()
+
+    rootdir = Path(args.rootdir).expanduser()
+    dumpdir = Path(args.dumpdir).expanduser()
+    # use absolute path
+    dumpdir = dumpdir.resolve()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+
+    assert rootdir.is_dir()
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    if args.dataset == "vctk":
+        sub_num_dev = 5
+
wav_dir = rootdir / "wav48_silence_trimmed" + train_wav_files = [] + dev_wav_files = [] + test_wav_files = [] + # only for test + for speaker in os.listdir(wav_dir): + wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac"))) + if len(wav_files) > 100: + train_wav_files += wav_files[:-sub_num_dev * 2] + dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev] + test_wav_files += wav_files[-sub_num_dev:] + else: + train_wav_files += wav_files + + else: + print("dataset should in {vctk} now!") + + train_dump_dir = dumpdir / "train" / "raw" + train_dump_dir.mkdir(parents=True, exist_ok=True) + dev_dump_dir = dumpdir / "dev" / "raw" + dev_dump_dir.mkdir(parents=True, exist_ok=True) + test_dump_dir = dumpdir / "test" / "raw" + test_dump_dir.mkdir(parents=True, exist_ok=True) + + # Extractor + mel_extractor = LogMelFBank( + sr=config.fs, + n_fft=config.n_fft, + hop_length=config.n_shift, + win_length=config.win_length, + window=config.window, + n_mels=config.n_mels, + fmin=config.fmin, + fmax=config.fmax, + # None here + norm=config.norm, + htk=config.htk, + power=config.power) + + # process for the 3 sections + if train_wav_files: + process_sentences( + config=config, + fps=train_wav_files, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, + nprocs=args.num_cpu) + if dev_wav_files: + process_sentences( + config=config, + fps=dev_wav_files, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, + nprocs=args.num_cpu) + if test_wav_files: + process_sentences( + config=config, + fps=test_wav_files, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, + nprocs=args.num_cpu) + + speaker_id_map_path = dumpdir / "speaker_id_map.txt" + get_spk_id_map(speaker_set, speaker_id_map_path) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/starganv2_vc/train.py b/paddlespeech/t2s/exps/starganv2_vc/train.py new file mode 100644 index 00000000..94fa3032 --- /dev/null +++ b/paddlespeech/t2s/exps/starganv2_vc/train.py @@ -0,0 +1,274 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
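
As the translated comments in preprocess.py above note, StarGANv2-VC extracts log-mels with base='e' (natural log), while the project's other TTS recipes use base='10'. Since ln(x) = log10(x) * ln(10), converting between the two conventions is a constant rescaling; a sketch under that assumption (the feature array is illustrative; assumes only numpy):

    import numpy as np

    # logmel_e: a base-e log-mel as dumped by the preprocess above, shape (T, 80)
    logmel_e = np.random.uniform(-10.0, 2.0, size=(100, 80)).astype(np.float32)

    # ln(x) = log10(x) * ln(10), so dividing by ln(10) converts base-e -> base-10
    logmel_10 = logmel_e / np.log(10.0)

    # and back again
    np.testing.assert_allclose(logmel_10 * np.log(10.0), logmel_e, atol=1e-4)
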
+import argparse
+import logging
+import os
+import shutil
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+import paddle
+import yaml
+from paddle import DataParallel
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+from paddle.optimizer import AdamW
+from paddle.optimizer.lr import OneCycleLR
+from yacs.config import CfgNode
+
+from paddlespeech.cli.utils import download_and_decompress
+from paddlespeech.resource.pretrained_models import StarGANv2VC_source
+from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn
+from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable
+from paddlespeech.t2s.models.starganv2_vc import ASRCNN
+from paddlespeech.t2s.models.starganv2_vc import Discriminator
+from paddlespeech.t2s.models.starganv2_vc import Generator
+from paddlespeech.t2s.models.starganv2_vc import JDCNet
+from paddlespeech.t2s.models.starganv2_vc import MappingNetwork
+from paddlespeech.t2s.models.starganv2_vc import StarGANv2VCEvaluator
+from paddlespeech.t2s.models.starganv2_vc import StarGANv2VCUpdater
+from paddlespeech.t2s.models.starganv2_vc import StyleEncoder
+from paddlespeech.t2s.training.extensions.snapshot import Snapshot
+from paddlespeech.t2s.training.extensions.visualizer import VisualDL
+from paddlespeech.t2s.training.seeding import seed_everything
+from paddlespeech.t2s.training.trainer import Trainer
+from paddlespeech.utils.env import MODEL_HOME
+
+
+def train_sp(args, config):
+    # decides device type and whether to run in parallel
+    # setup running environment correctly
+    world_size = paddle.distributed.get_world_size()
+    if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
+        paddle.set_device("cpu")
+    else:
+        paddle.set_device("gpu")
+        if world_size > 1:
+            paddle.distributed.init_parallel_env()
+
+    # set the random seed; it is a must for multiprocess training
+    seed_everything(config.seed)
+
+    print(
+        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
+    )
+    # TODO: edit these fields
+    fields = ["speech", "speech_lengths"]
+    converters = {"speech": np.load}
+
+    collate_fn = build_starganv2_vc_collate_fn(
+        latent_dim=config['mapping_network_params']['latent_dim'],
+        max_mel_length=config['max_mel_length'])
+
+    # dataloader has been too verbose
+    logging.getLogger("DataLoader").disabled = True
+
+    # construct dataset for training and validation
+    with jsonlines.open(args.train_metadata, 'r') as reader:
+        train_metadata = list(reader)
+    train_dataset = StarGANv2VCDataTable(data=train_metadata)
+    with jsonlines.open(args.dev_metadata, 'r') as reader:
+        dev_metadata = list(reader)
+    dev_dataset = StarGANv2VCDataTable(data=dev_metadata)
+
+    # collate function and dataloader
+    train_sampler = DistributedBatchSampler(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        drop_last=True)
+
+    print("samplers done!")
+
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_sampler=train_sampler,
+        collate_fn=collate_fn,
+        num_workers=config.num_workers)
+
+    dev_dataloader = DataLoader(
+        dev_dataset,
+        shuffle=False,
+        drop_last=False,
+        batch_size=config.batch_size,
+        collate_fn=collate_fn,
+        num_workers=config.num_workers)
+
+    print("dataloaders done!")
+
+    # load model
+    model_version = '1.0'
+    uncompress_path = download_and_decompress(StarGANv2VC_source[model_version],
+                                              MODEL_HOME)
+    # adjust num_domains according to the number of speakers
+    # the pretrained model from the reference implementation and default.yaml both default to 20
+    if args.speaker_dict is not None:
+        with open(args.speaker_dict,
'rt', encoding='utf-8') as f: + spk_id = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id) + print("spk_num:", spk_num) + config['mapping_network_params']['num_domains'] = spk_num + config['style_encoder_params']['num_domains'] = spk_num + config['discriminator_params']['num_domains'] = spk_num + + generator = Generator(**config['generator_params']) + mapping_network = MappingNetwork(**config['mapping_network_params']) + style_encoder = StyleEncoder(**config['style_encoder_params']) + discriminator = Discriminator(**config['discriminator_params']) + + # load pretrained model + jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz') + asr_model_dir = os.path.join(uncompress_path, 'asr.pdz') + + F0_model = JDCNet(num_class=1, seq_len=config['max_mel_length']) + F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params']) + F0_model.eval() + + asr_model = ASRCNN(**config['asr_params']) + asr_model.set_state_dict(paddle.load(asr_model_dir)['main_params']) + asr_model.eval() + + if world_size > 1: + generator = DataParallel(generator) + discriminator = DataParallel(discriminator) + print("models done!") + + lr_schedule_g = OneCycleLR(**config["generator_scheduler_params"]) + optimizer_g = AdamW( + learning_rate=lr_schedule_g, + parameters=generator.parameters(), + **config["generator_optimizer_params"]) + + lr_schedule_s = OneCycleLR(**config["style_encoder_scheduler_params"]) + optimizer_s = AdamW( + learning_rate=lr_schedule_s, + parameters=style_encoder.parameters(), + **config["style_encoder_optimizer_params"]) + + lr_schedule_m = OneCycleLR(**config["mapping_network_scheduler_params"]) + optimizer_m = AdamW( + learning_rate=lr_schedule_m, + parameters=mapping_network.parameters(), + **config["mapping_network_optimizer_params"]) + + lr_schedule_d = OneCycleLR(**config["discriminator_scheduler_params"]) + optimizer_d = AdamW( + learning_rate=lr_schedule_d, + parameters=discriminator.parameters(), + **config["discriminator_optimizer_params"]) + print("optimizers done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + updater = StarGANv2VCUpdater( + models={ + "generator": generator, + "style_encoder": style_encoder, + "mapping_network": mapping_network, + "discriminator": discriminator, + "F0_model": F0_model, + "asr_model": asr_model, + }, + optimizers={ + "generator": optimizer_g, + "style_encoder": optimizer_s, + "mapping_network": optimizer_m, + "discriminator": optimizer_d, + }, + schedulers={ + "generator": lr_schedule_g, + "style_encoder": lr_schedule_s, + "mapping_network": lr_schedule_m, + "discriminator": lr_schedule_d, + }, + dataloader=train_dataloader, + g_loss_params=config.loss_params.g_loss, + d_loss_params=config.loss_params.d_loss, + adv_cls_epoch=config.loss_params.adv_cls_epoch, + con_reg_epoch=config.loss_params.con_reg_epoch, + output_dir=output_dir) + + evaluator = StarGANv2VCEvaluator( + models={ + "generator": generator, + "style_encoder": style_encoder, + "mapping_network": mapping_network, + "discriminator": discriminator, + "F0_model": F0_model, + "asr_model": asr_model, + }, + dataloader=dev_dataloader, + g_loss_params=config.loss_params.g_loss, + d_loss_params=config.loss_params.d_loss, + adv_cls_epoch=config.loss_params.adv_cls_epoch, + con_reg_epoch=config.loss_params.con_reg_epoch, + output_dir=output_dir) + + trainer = 
Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
+
+    if dist.get_rank() == 0:
+        trainer.extend(evaluator, trigger=(1, "epoch"))
+        trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
+        trainer.extend(
+            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
+    print("Trainer Done!")
+
+    trainer.run()
+
+
+def main():
+    # parse args and config and redirect to train_sp
+
+    parser = argparse.ArgumentParser(description="Train a StarGANv2-VC model.")
+    parser.add_argument("--config", type=str, help="StarGANv2-VC config file.")
+    parser.add_argument("--train-metadata", type=str, help="training data.")
+    parser.add_argument("--dev-metadata", type=str, help="dev data.")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+    parser.add_argument(
+        "--speaker-dict",
+        type=str,
+        default=None,
+        help="speaker id map file for multiple speaker model.")
+
+    args = parser.parse_args()
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    # dispatch
+    if args.ngpu > 1:
+        dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
+    else:
+        train_sp(args, config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/t2s/exps/starganv2_vc/vc.py b/paddlespeech/t2s/exps/starganv2_vc/vc.py
index ffb25741..24d3dcf8 100644
--- a/paddlespeech/t2s/exps/starganv2_vc/vc.py
+++ b/paddlespeech/t2s/exps/starganv2_vc/vc.py
@@ -57,9 +57,10 @@ def get_mel_extractor():
 
 def preprocess(wave, mel_extractor):
+    # (T, 80)
     logmel = mel_extractor.get_log_mel_fbank(wave, base='e')
-    # [1, 80, 1011]
     mean, std = -4, 4
+    # [1, 80, T]
     mel_tensor = (paddle.to_tensor(logmel.T).unsqueeze(0) - mean) / std
     return mel_tensor
 
@@ -67,6 +68,7 @@ def preprocess(wave, mel_extractor):
 def compute_style(speaker_dicts, mel_extractor, style_encoder, mapping_network):
     reference_embeddings = {}
     for key, (path, speaker) in speaker_dicts.items():
+        # path = ''
         if path == '':
             label = paddle.to_tensor([speaker], dtype=paddle.int64)
             latent_dim = mapping_network.shared[0].weight.shape[0]
@@ -164,6 +166,15 @@ def voice_conversion(args, uncompress_path):
     wave, sr = librosa.load(args.source_path, sr=24000)
     source = preprocess(wave=wave, mel_extractor=mel_extractor)
 
+    # # test whether the output of preprocess.py is OK
+    # # using raw features and normalizing them here works
+    # # using the already-normalized features here also works
+    # import numpy as np
+    # source = np.load("~/PaddleSpeech_stargan_preprocess/PaddleSpeech/examples/vctk/vc3/dump/train/norm/p329_414_speech.npy")
+    # # !!! the same op as applied after the mel_extractor norm
+    # # [1, 80, T]
+    # source = paddle.to_tensor(source.T).unsqueeze(0)
+
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
     orig_wav_name = str(output_dir / 'orig_voc.wav')
diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index 2b958b56..57c79dee 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -506,7 +506,7 @@ def am_to_static(am_inference,
         am_inference = jit.to_static(
             am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
-    elif am_name == 'vits':
+    elif am_name == 'vits' or am_name == 'jets':
         if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
             am_inference = jit.to_static(
                 am_inference,
diff --git a/paddlespeech/t2s/exps/tacotron2/preprocess.py b/paddlespeech/t2s/exps/tacotron2/preprocess.py
index c27b9769..46b72591 100644
--- a/paddlespeech/t2s/exps/tacotron2/preprocess.py
+++ b/paddlespeech/t2s/exps/tacotron2/preprocess.py
@@ -311,6 +311,7 @@ def main():
         sentences=sentences,
         output_dir=dev_dump_dir,
         mel_extractor=mel_extractor,
+        nprocs=args.num_cpu,
         cut_sil=args.cut_sil,
         spk_emb_dir=spk_emb_dir)
     if test_wav_files:
diff --git a/paddlespeech/t2s/exps/vits/lite_predict.py b/paddlespeech/t2s/exps/vits/lite_predict.py
index 790cd48e..32a544b7 100644
--- a/paddlespeech/t2s/exps/vits/lite_predict.py
+++ b/paddlespeech/t2s/exps/vits/lite_predict.py
@@ -21,6 +21,7 @@ from paddlespeech.t2s.exps.lite_syn_utils import get_lite_am_output
 from paddlespeech.t2s.exps.lite_syn_utils import get_lite_predictor
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
+from paddlespeech.t2s.utils import str2bool
 
 
 def parse_args():
@@ -75,12 +76,12 @@ def main():
     # frontend
     frontend = get_frontend(
         lang=args.lang,
-        phones_dict=args.phones_dict,
-        tones_dict=args.tones_dict)
+        phones_dict=args.phones_dict)
 
     # am_predictor
+    # VITS can only run on ARM
     am_predictor = get_lite_predictor(
-        model_dir=args.inference_dir, model_file=args.am + "_x86.nb")
+        model_dir=args.inference_dir, model_file=args.am + "_arm.nb")
 
     # model: {model_name}_{dataset}
     am_dataset = args.am[args.am.rindex('_') + 1:]
diff --git a/paddlespeech/t2s/exps/vits/preprocess.py b/paddlespeech/t2s/exps/vits/preprocess.py
index d6b226a2..23c959d4 100644
--- a/paddlespeech/t2s/exps/vits/preprocess.py
+++ b/paddlespeech/t2s/exps/vits/preprocess.py
@@ -321,6 +321,7 @@ def main():
         sentences=sentences,
         output_dir=dev_dump_dir,
         spec_extractor=spec_extractor,
+        nprocs=args.num_cpu,
         cut_sil=args.cut_sil,
         spk_emb_dir=spk_emb_dir)
     if test_wav_files:
diff --git a/paddlespeech/t2s/frontend/generate_lexicon.py b/paddlespeech/t2s/frontend/generate_lexicon.py
index 6b467d00..4fb748a6 100644
--- a/paddlespeech/t2s/frontend/generate_lexicon.py
+++ b/paddlespeech/t2s/frontend/generate_lexicon.py
@@ -45,7 +45,7 @@ def rule(C, V, R, T):
     'u' in syllables when certain conditions are satisfied.
 
     'i' is distinguished when appeared in phonemes, and separated into 3 categories, 'i', 'ii' and 'iii'.
-    Erhua is is possibly applied to every finals, except for finals that already ends with 'r'.
+    Erhua may be applied to every final, except for finals that already end with 'r'.
 
     When a syllable is impossible or does not have any characters with this pronunciation, return None to filter it out.
""" diff --git a/paddlespeech/t2s/models/hifigan/hifigan.py b/paddlespeech/t2s/models/hifigan/hifigan.py index 7a01840e..2759af9d 100644 --- a/paddlespeech/t2s/models/hifigan/hifigan.py +++ b/paddlespeech/t2s/models/hifigan/hifigan.py @@ -37,8 +37,8 @@ class HiFiGANGenerator(nn.Layer): channels: int=512, global_channels: int=-1, kernel_size: int=7, - upsample_scales: List[int]=(8, 8, 2, 2), - upsample_kernel_sizes: List[int]=(16, 16, 4, 4), + upsample_scales: List[int]=(5, 5, 4, 3), + upsample_kernel_sizes: List[int]=(10, 10, 8, 6), resblock_kernel_sizes: List[int]=(3, 7, 11), resblock_dilations: List[List[int]]=[(1, 3, 5), (1, 3, 5), (1, 3, 5)], @@ -47,8 +47,13 @@ class HiFiGANGenerator(nn.Layer): nonlinear_activation: str="leakyrelu", nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.1}, use_weight_norm: bool=True, - init_type: str="xavier_uniform", ): + init_type: str="xavier_uniform", + use_istft: bool=False, + istft_layer_id: int=2, + n_fft: int=2048, + win_length: int=1200, ): """Initialize HiFiGANGenerator module. + Args: in_channels (int): Number of input channels. @@ -79,6 +84,14 @@ class HiFiGANGenerator(nn.Layer): use_weight_norm (bool): Whether to use weight norm. If set to true, it will be applied to all of the conv layers. + use_istft (bool): + If set to true, it will be a iSTFTNet based on hifigan. + istft_layer_id (int): + Use istft after istft_layer_id layers of upsample layer if use_istft=True + n_fft (int): + Number of fft points in feature extraction + win_length (int): + Window length in feature extraction """ super().__init__() @@ -89,9 +102,11 @@ class HiFiGANGenerator(nn.Layer): assert kernel_size % 2 == 1, "Kernel size must be odd number." assert len(upsample_scales) == len(upsample_kernel_sizes) assert len(resblock_dilations) == len(resblock_kernel_sizes) + assert len(upsample_scales) >= istft_layer_id if use_istft else True # define modules - self.num_upsamples = len(upsample_kernel_sizes) + self.num_upsamples = len( + upsample_kernel_sizes) if not use_istft else istft_layer_id self.num_blocks = len(resblock_kernel_sizes) self.input_conv = nn.Conv1D( in_channels, @@ -101,7 +116,7 @@ class HiFiGANGenerator(nn.Layer): padding=(kernel_size - 1) // 2, ) self.upsamples = nn.LayerList() self.blocks = nn.LayerList() - for i in range(len(upsample_kernel_sizes)): + for i in range(self.num_upsamples): assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] self.upsamples.append( nn.Sequential( @@ -126,15 +141,36 @@ class HiFiGANGenerator(nn.Layer): nonlinear_activation=nonlinear_activation, nonlinear_activation_params=nonlinear_activation_params, )) - self.output_conv = nn.Sequential( - nn.LeakyReLU(), - nn.Conv1D( + self.use_istft = use_istft + if self.use_istft: + self.istft_hop_size = 1 + for j in range(istft_layer_id, len(upsample_scales)): + self.istft_hop_size *= upsample_scales[j] + s = 1 + for j in range(istft_layer_id): + s *= upsample_scales[j] + self.istft_n_fft = int(n_fft / s) if ( + n_fft / s) % 2 == 0 else int((n_fft / s + 2) - n_fft / s % 2) + self.istft_win_length = int(win_length / s) if ( + win_length / + s) % 2 == 0 else int((win_length / s + 2) - win_length / s % 2) + self.reflection_pad = nn.Pad1D(padding=[1, 0], mode='reflect') + self.output_conv = nn.Conv1D( channels // (2**(i + 1)), - out_channels, + (self.istft_n_fft // 2 + 1) * 2, kernel_size, 1, - padding=(kernel_size - 1) // 2, ), - nn.Tanh(), ) + padding=(kernel_size - 1) // 2, ) + else: + self.output_conv = nn.Sequential( + nn.LeakyReLU(), + nn.Conv1D( + channels // (2**(i + 
1)), + out_channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, ), + nn.Tanh(), ) if global_channels > 0: self.global_conv = nn.Conv1D(global_channels, channels, 1) @@ -167,7 +203,29 @@ class HiFiGANGenerator(nn.Layer): for j in range(self.num_blocks): cs += self.blocks[i * self.num_blocks + j](c) c = cs / self.num_blocks - c = self.output_conv(c) + + if self.use_istft: + c = F.leaky_relu(c) + c = self.reflection_pad(c) + c = self.output_conv(c) + """ + Input of Exp operator, an N-D Tensor, with data type float32, float64 or float16. + https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/exp_en.html + Use Euler's formula to implement spec*paddle.exp(1j*phase) + """ + spec = paddle.exp(c[:, :self.istft_n_fft // 2 + 1, :]) + phase = paddle.sin(c[:, self.istft_n_fft // 2 + 1:, :]) + + c = paddle.complex(spec * (paddle.cos(phase)), + spec * (paddle.sin(phase))) + c = paddle.signal.istft( + c, + n_fft=self.istft_n_fft, + hop_length=self.istft_hop_size, + win_length=self.istft_win_length) + c = c.unsqueeze(1) + else: + c = self.output_conv(c) return c diff --git a/paddlespeech/t2s/models/jets/__init__.py b/paddlespeech/t2s/models/jets/__init__.py new file mode 100644 index 00000000..dec4a331 --- /dev/null +++ b/paddlespeech/t2s/models/jets/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .jets import * +from .jets_updater import * diff --git a/paddlespeech/t2s/models/jets/alignments.py b/paddlespeech/t2s/models/jets/alignments.py new file mode 100644 index 00000000..998f67e2 --- /dev/null +++ b/paddlespeech/t2s/models/jets/alignments.py @@ -0,0 +1,182 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generator module in JETS. + +This code is based on https://github.com/imdanboy/jets. 
+ +""" +import numpy as np +import paddle +import paddle.nn.functional as F +from numba import jit +from paddle import nn + +from paddlespeech.t2s.modules.masked_fill import masked_fill + + +class AlignmentModule(nn.Layer): + """Alignment Learning Framework proposed for parallel TTS models in: + https://arxiv.org/abs/2108.10447 + """ + + def __init__(self, adim, odim): + super().__init__() + self.t_conv1 = nn.Conv1D(adim, adim, kernel_size=3, padding=1) + self.t_conv2 = nn.Conv1D(adim, adim, kernel_size=1, padding=0) + + self.f_conv1 = nn.Conv1D(odim, adim, kernel_size=3, padding=1) + self.f_conv2 = nn.Conv1D(adim, adim, kernel_size=3, padding=1) + self.f_conv3 = nn.Conv1D(adim, adim, kernel_size=1, padding=0) + + def forward(self, text, feats, x_masks=None): + """ + Args: + text (Tensor): Batched text embedding (B, T_text, adim) + feats (Tensor): Batched acoustic feature (B, T_feats, odim) + x_masks (Tensor): Mask tensor (B, T_text) + + Returns: + Tensor: log probability of attention matrix (B, T_feats, T_text) + """ + + text = text.transpose((0, 2, 1)) + text = F.relu(self.t_conv1(text)) + text = self.t_conv2(text) + text = text.transpose((0, 2, 1)) + + feats = feats.transpose((0, 2, 1)) + feats = F.relu(self.f_conv1(feats)) + feats = F.relu(self.f_conv2(feats)) + feats = self.f_conv3(feats) + feats = feats.transpose((0, 2, 1)) + + dist = feats.unsqueeze(2) - text.unsqueeze(1) + dist = paddle.linalg.norm(dist, p=2, axis=3) + score = -dist + + if x_masks is not None: + x_masks = x_masks.unsqueeze(-2) + score = masked_fill(score, x_masks, -np.inf) + log_p_attn = F.log_softmax(score, axis=-1) + return log_p_attn, score + + +@jit(nopython=True) +def _monotonic_alignment_search(log_p_attn): + # https://arxiv.org/abs/2005.11129 + T_mel = log_p_attn.shape[0] + T_inp = log_p_attn.shape[1] + Q = np.full((T_inp, T_mel), fill_value=-np.inf) + + log_prob = log_p_attn.transpose(1, 0) # -> (T_inp,T_mel) + # 1. Q <- init first row for all j + for j in range(T_mel): + Q[0, j] = log_prob[0, :j + 1].sum() + + # 2. + for j in range(1, T_mel): + for i in range(1, min(j + 1, T_inp)): + Q[i, j] = max(Q[i - 1, j - 1], Q[i, j - 1]) + log_prob[i, j] + + # 3. 
+ A = np.full((T_mel, ), fill_value=T_inp - 1) + for j in range(T_mel - 2, -1, -1): # T_mel-2, ..., 0 + # 'i' in {A[j+1]-1, A[j+1]} + i_a = A[j + 1] - 1 + i_b = A[j + 1] + if i_b == 0: + argmax_i = 0 + elif Q[i_a, j] >= Q[i_b, j]: + argmax_i = i_a + else: + argmax_i = i_b + A[j] = argmax_i + return A + + +def viterbi_decode(log_p_attn, text_lengths, feats_lengths): + """ + Args: + log_p_attn (Tensor): + Batched log probability of attention matrix (B, T_feats, T_text) + text_lengths (Tensor): + Text length tensor (B,) + feats_legnths (Tensor): + Feature length tensor (B,) + Returns: + Tensor: + Batched token duration extracted from `log_p_attn` (B,T_text) + Tensor: + binarization loss tensor () + """ + B = log_p_attn.shape[0] + T_text = log_p_attn.shape[2] + device = log_p_attn.place + + bin_loss = 0 + ds = paddle.zeros((B, T_text), dtype="int32") + for b in range(B): + cur_log_p_attn = log_p_attn[b, :feats_lengths[b], :text_lengths[b]] + viterbi = _monotonic_alignment_search(cur_log_p_attn.numpy()) + _ds = np.bincount(viterbi) + ds[b, :len(_ds)] = paddle.to_tensor( + _ds, place=device, dtype="int32") + + t_idx = paddle.arange(feats_lengths[b]) + bin_loss = bin_loss - cur_log_p_attn[t_idx, viterbi].mean() + bin_loss = bin_loss / B + return ds, bin_loss + + +@jit(nopython=True) +def _average_by_duration(ds, xs, text_lengths, feats_lengths): + B = ds.shape[0] + # xs_avg = np.zeros_like(ds) + xs_avg = np.zeros(shape=ds.shape, dtype=np.float32) + ds = ds.astype(np.int32) + for b in range(B): + t_text = text_lengths[b] + t_feats = feats_lengths[b] + d = ds[b, :t_text] + d_cumsum = d.cumsum() + d_cumsum = [0] + list(d_cumsum) + x = xs[b, :t_feats] + for n, (start, end) in enumerate(zip(d_cumsum[:-1], d_cumsum[1:])): + if len(x[start:end]) != 0: + xs_avg[b, n] = x[start:end].mean() + else: + xs_avg[b, n] = 0 + return xs_avg + + +def average_by_duration(ds, xs, text_lengths, feats_lengths): + """ + Args: + ds (Tensor): + Batched token duration (B,T_text) + xs (Tensor): + Batched feature sequences to be averaged (B,T_feats) + text_lengths (Tensor): + Text length tensor (B,) + feats_lengths (Tensor): + Feature length tensor (B,) + Returns: + Tensor: Batched feature averaged according to the token duration (B, T_text) + """ + device = ds.place + args = [ds, xs, text_lengths, feats_lengths] + args = [arg.numpy() for arg in args] + xs_avg = _average_by_duration(*args) + xs_avg = paddle.to_tensor(xs_avg, place=device) + return xs_avg diff --git a/paddlespeech/t2s/models/jets/generator.py b/paddlespeech/t2s/models/jets/generator.py new file mode 100644 index 00000000..9580d17d --- /dev/null +++ b/paddlespeech/t2s/models/jets/generator.py @@ -0,0 +1,897 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generator module in JETS. + +This code is based on https://github.com/imdanboy/jets. 
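Before moving on to the generator, a minimal sketch of how the two public helpers above fit together (toy shapes that follow the docstrings; illustrative only, and it assumes a working `paddle` runtime with `numba` installed):

```python
# Sketch: toy end-to-end use of the alignment utilities defined above.
import paddle
import paddle.nn.functional as F

from paddlespeech.t2s.models.jets.alignments import average_by_duration
from paddlespeech.t2s.models.jets.alignments import viterbi_decode

B, T_feats, T_text = 1, 6, 3
# log attention probabilities, shaped (B, T_feats, T_text) as in the docstrings
log_p_attn = F.log_softmax(paddle.randn([B, T_feats, T_text]), axis=-1)
text_lengths = paddle.to_tensor([T_text], dtype="int64")
feats_lengths = paddle.to_tensor([T_feats], dtype="int64")

# hard monotonic durations per token, plus the binarization loss
ds, bin_loss = viterbi_decode(log_p_attn, text_lengths, feats_lengths)
# the durations partition the T_feats frames across the T_text tokens
assert int(ds.sum()) == T_feats

# average a frame-level sequence (e.g. pitch) into one value per token
frame_pitch = paddle.randn([B, T_feats])
token_pitch = average_by_duration(ds, frame_pitch, text_lengths, feats_lengths)
print(ds.numpy(), token_pitch.shape)  # per-token durations, (1, 3) tensor
```

`viterbi_decode` turns the soft attention into hard per-token durations, and `average_by_duration` pools frame-level pitch or energy into the token-level targets consumed by the variance predictors in the generator below.

diff --git a/paddlespeech/t2s/models/jets/generator.py b/paddlespeech/t2s/models/jets/generator.py
new file mode 100644
index 00000000..9580d17d
--- /dev/null
+++ b/paddlespeech/t2s/models/jets/generator.py
@@ -0,0 +1,897 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Generator module in JETS.
+
+This code is based on https://github.com/imdanboy/jets.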
+ +""" +import logging +import math +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple + +import numpy as np +import paddle +from paddle import nn +from typeguard import check_argument_types + +from paddlespeech.t2s.models.hifigan import HiFiGANGenerator +from paddlespeech.t2s.models.jets.alignments import AlignmentModule +from paddlespeech.t2s.models.jets.alignments import average_by_duration +from paddlespeech.t2s.models.jets.alignments import viterbi_decode +from paddlespeech.t2s.models.jets.length_regulator import GaussianUpsampling +from paddlespeech.t2s.modules.nets_utils import get_random_segments +from paddlespeech.t2s.modules.nets_utils import initialize +from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask +from paddlespeech.t2s.modules.nets_utils import make_pad_mask +from paddlespeech.t2s.modules.predictor.duration_predictor import DurationPredictor +from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator +from paddlespeech.t2s.modules.predictor.variance_predictor import VariancePredictor +from paddlespeech.t2s.modules.style_encoder import StyleEncoder +from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding +from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding +from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder +from paddlespeech.t2s.modules.transformer.encoder import TransformerEncoder + + +class JETSGenerator(nn.Layer): + """Generator module in JETS. + """ + + def __init__( + self, + idim: int, + odim: int, + adim: int=256, + aheads: int=2, + elayers: int=4, + eunits: int=1024, + dlayers: int=4, + dunits: int=1024, + positionwise_layer_type: str="conv1d", + positionwise_conv_kernel_size: int=1, + use_scaled_pos_enc: bool=True, + use_batch_norm: bool=True, + encoder_normalize_before: bool=True, + decoder_normalize_before: bool=True, + encoder_concat_after: bool=False, + decoder_concat_after: bool=False, + reduction_factor: int=1, + encoder_type: str="transformer", + decoder_type: str="transformer", + transformer_enc_dropout_rate: float=0.1, + transformer_enc_positional_dropout_rate: float=0.1, + transformer_enc_attn_dropout_rate: float=0.1, + transformer_dec_dropout_rate: float=0.1, + transformer_dec_positional_dropout_rate: float=0.1, + transformer_dec_attn_dropout_rate: float=0.1, + transformer_activation_type: str="relu", + # only for conformer + conformer_rel_pos_type: str="legacy", + conformer_pos_enc_layer_type: str="rel_pos", + conformer_self_attn_layer_type: str="rel_selfattn", + conformer_activation_type: str="swish", + use_macaron_style_in_conformer: bool=True, + use_cnn_in_conformer: bool=True, + zero_triu: bool=False, + conformer_enc_kernel_size: int=7, + conformer_dec_kernel_size: int=31, + # duration predictor + duration_predictor_layers: int=2, + duration_predictor_chans: int=384, + duration_predictor_kernel_size: int=3, + duration_predictor_dropout_rate: float=0.1, + # energy predictor + energy_predictor_layers: int=2, + energy_predictor_chans: int=384, + energy_predictor_kernel_size: int=3, + energy_predictor_dropout: float=0.5, + energy_embed_kernel_size: int=9, + energy_embed_dropout: float=0.5, + stop_gradient_from_energy_predictor: bool=False, + # pitch predictor + pitch_predictor_layers: int=2, + pitch_predictor_chans: int=384, + pitch_predictor_kernel_size: int=3, + pitch_predictor_dropout: float=0.5, + pitch_embed_kernel_size: int=9, + 
pitch_embed_dropout: float=0.5, + stop_gradient_from_pitch_predictor: bool=False, + # extra embedding related + spks: Optional[int]=None, + langs: Optional[int]=None, + spk_embed_dim: Optional[int]=None, + spk_embed_integration_type: str="add", + use_gst: bool=False, + gst_tokens: int=10, + gst_heads: int=4, + gst_conv_layers: int=6, + gst_conv_chans_list: Sequence[int]=(32, 32, 64, 64, 128, 128), + gst_conv_kernel_size: int=3, + gst_conv_stride: int=2, + gst_gru_layers: int=1, + gst_gru_units: int=128, + # training related + init_type: str="xavier_uniform", + init_enc_alpha: float=1.0, + init_dec_alpha: float=1.0, + use_masking: bool=False, + use_weighted_masking: bool=False, + segment_size: int=64, + # hifigan generator + generator_out_channels: int=1, + generator_channels: int=512, + generator_global_channels: int=-1, + generator_kernel_size: int=7, + generator_upsample_scales: List[int]=[8, 8, 2, 2], + generator_upsample_kernel_sizes: List[int]=[16, 16, 4, 4], + generator_resblock_kernel_sizes: List[int]=[3, 7, 11], + generator_resblock_dilations: List[List[int]]=[[1, 3, 5], [1, 3, 5], + [1, 3, 5]], + generator_use_additional_convs: bool=True, + generator_bias: bool=True, + generator_nonlinear_activation: str="LeakyReLU", + generator_nonlinear_activation_params: Dict[ + str, Any]={"negative_slope": 0.1}, + generator_use_weight_norm: bool=True, ): + """Initialize JETS generator module. + + Args: + idim (int): + Dimension of the inputs. + odim (int): + Dimension of the outputs. + adim (int): + Attention dimension. + aheads (int): + Number of attention heads. + elayers (int): + Number of encoder layers. + eunits (int): + Number of encoder hidden units. + dlayers (int): + Number of decoder layers. + dunits (int): + Number of decoder hidden units. + use_scaled_pos_enc (bool): + Whether to use trainable scaled pos encoding. + use_batch_norm (bool): + Whether to use batch normalization in encoder prenet. + encoder_normalize_before (bool): + Whether to apply layernorm layer before encoder block. + decoder_normalize_before (bool): + Whether to apply layernorm layer before decoder block. + encoder_concat_after (bool): + Whether to concatenate attention layer's input and output in encoder. + decoder_concat_after (bool): + Whether to concatenate attention layer's input and output in decoder. + reduction_factor (int): + Reduction factor. + encoder_type (str): + Encoder type ("transformer" or "conformer"). + decoder_type (str): + Decoder type ("transformer" or "conformer"). + transformer_enc_dropout_rate (float): + Dropout rate in encoder except attention and positional encoding. + transformer_enc_positional_dropout_rate (float): + Dropout rate after encoder positional encoding. + transformer_enc_attn_dropout_rate (float): + Dropout rate in encoder self-attention module. + transformer_dec_dropout_rate (float): + Dropout rate in decoder except attention & positional encoding. + transformer_dec_positional_dropout_rate (float): + Dropout rate after decoder positional encoding. + transformer_dec_attn_dropout_rate (float): + Dropout rate in decoder self-attention module. + conformer_rel_pos_type (str): + Relative pos encoding type in conformer. + conformer_pos_enc_layer_type (str): + Pos encoding layer type in conformer. + conformer_self_attn_layer_type (str): + Self-attention layer type in conformer + conformer_activation_type (str): + Activation function type in conformer. + use_macaron_style_in_conformer: + Whether to use macaron style FFN. + use_cnn_in_conformer: + Whether to use CNN in conformer. 
+ zero_triu: + Whether to use zero triu in relative self-attention module. + conformer_enc_kernel_size: + Kernel size of encoder conformer. + conformer_dec_kernel_size: + Kernel size of decoder conformer. + duration_predictor_layers (int): + Number of duration predictor layers. + duration_predictor_chans (int): + Number of duration predictor channels. + duration_predictor_kernel_size (int): + Kernel size of duration predictor. + duration_predictor_dropout_rate (float): + Dropout rate in duration predictor. + pitch_predictor_layers (int): + Number of pitch predictor layers. + pitch_predictor_chans (int): + Number of pitch predictor channels. + pitch_predictor_kernel_size (int): + Kernel size of pitch predictor. + pitch_predictor_dropout_rate (float): + Dropout rate in pitch predictor. + pitch_embed_kernel_size (float): + Kernel size of pitch embedding. + pitch_embed_dropout_rate (float): + Dropout rate for pitch embedding. + stop_gradient_from_pitch_predictor: + Whether to stop gradient from pitch predictor to encoder. + energy_predictor_layers (int): + Number of energy predictor layers. + energy_predictor_chans (int): + Number of energy predictor channels. + energy_predictor_kernel_size (int): + Kernel size of energy predictor. + energy_predictor_dropout_rate (float): + Dropout rate in energy predictor. + energy_embed_kernel_size (float): + Kernel size of energy embedding. + energy_embed_dropout_rate (float): + Dropout rate for energy embedding. + stop_gradient_from_energy_predictor: + Whether to stop gradient from energy predictor to encoder. + spks (Optional[int]): + Number of speakers. If set to > 1, assume that the sids will be provided as the input and use sid embedding layer. + langs (Optional[int]): + Number of languages. If set to > 1, assume that the lids will be provided as the input and use sid embedding layer. + spk_embed_dim (Optional[int]): + Speaker embedding dimension. If set to > 0, assume that spembs will be provided as the input. + spk_embed_integration_type: + How to integrate speaker embedding. + use_gst (str): + Whether to use global style token. + gst_tokens (int): + The number of GST embeddings. + gst_heads (int): + The number of heads in GST multihead attention. + gst_conv_layers (int): + The number of conv layers in GST. + gst_conv_chans_list: (Sequence[int]): + List of the number of channels of conv layers in GST. + gst_conv_kernel_size (int): + Kernel size of conv layers in GST. + gst_conv_stride (int): + Stride size of conv layers in GST. + gst_gru_layers (int): + The number of GRU layers in GST. + gst_gru_units (int): + The number of GRU units in GST. + init_type (str): + How to initialize transformer parameters. + init_enc_alpha (float): + Initial value of alpha in scaled pos encoding of the encoder. + init_dec_alpha (float): + Initial value of alpha in scaled pos encoding of the decoder. + use_masking (bool): + Whether to apply masking for padded part in loss calculation. + use_weighted_masking (bool): + Whether to apply weighted masking in loss calculation. + segment_size (int): + Segment size for random windowed discriminator + generator_out_channels (int): + Number of output channels. + generator_channels (int): + Number of hidden representation channels. + generator_global_channels (int): + Number of global conditioning channels. + generator_kernel_size (int): + Kernel size of initial and final conv layer. + generator_upsample_scales (List[int]): + List of upsampling scales. 
+ generator_upsample_kernel_sizes (List[int]): + List of kernel sizes for upsample layers. + generator_resblock_kernel_sizes (List[int]): + List of kernel sizes for residual blocks. + generator_resblock_dilations (List[List[int]]): + List of list of dilations for residual blocks. + generator_use_additional_convs (bool): + Whether to use additional conv layers in residual blocks. + generator_bias (bool): + Whether to add bias parameter in convolution layers. + generator_nonlinear_activation (str): + Activation function module name. + generator_nonlinear_activation_params (Dict[str, Any]): + Hyperparameters for activation function. + generator_use_weight_norm (bool): + Whether to use weight norm. If set to true, it will be applied to all of the conv layers. + + """ + super().__init__() + self.segment_size = segment_size + self.upsample_factor = int(np.prod(generator_upsample_scales)) + self.idim = idim + self.odim = odim + self.reduction_factor = reduction_factor + self.encoder_type = encoder_type + self.decoder_type = decoder_type + self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor + self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor + self.use_scaled_pos_enc = use_scaled_pos_enc + self.use_gst = use_gst + + # use idx 0 as padding idx + self.padding_idx = 0 + + # get positional encoding layer type + transformer_pos_enc_layer_type = "scaled_abs_pos" if self.use_scaled_pos_enc else "abs_pos" + + # check relative positional encoding compatibility + if "conformer" in [encoder_type, decoder_type]: + if conformer_rel_pos_type == "legacy": + if conformer_pos_enc_layer_type == "rel_pos": + conformer_pos_enc_layer_type = "legacy_rel_pos" + logging.warning( + "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' " + "due to the compatibility. If you want to use the new one, " + "please use conformer_pos_enc_layer_type = 'latest'.") + if conformer_self_attn_layer_type == "rel_selfattn": + conformer_self_attn_layer_type = "legacy_rel_selfattn" + logging.warning( + "Fallback to " + "conformer_self_attn_layer_type = 'legacy_rel_selfattn' " + "due to the compatibility. 
If you want to use the new one, " + "please use conformer_pos_enc_layer_type = 'latest'.") + elif conformer_rel_pos_type == "latest": + assert conformer_pos_enc_layer_type != "legacy_rel_pos" + assert conformer_self_attn_layer_type != "legacy_rel_selfattn" + else: + raise ValueError( + f"Unknown rel_pos_type: {conformer_rel_pos_type}") + + # define encoder + encoder_input_layer = nn.Embedding( + num_embeddings=idim, + embedding_dim=adim, + padding_idx=self.padding_idx) + if encoder_type == "transformer": + self.encoder = TransformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + pos_enc_layer_type=transformer_pos_enc_layer_type, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + activation_type=transformer_activation_type) + elif encoder_type == "conformer": + self.encoder = ConformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_enc_kernel_size, + zero_triu=zero_triu, ) + else: + raise ValueError(f"{encoder_type} is not supported.") + + # define GST + if self.use_gst: + self.gst = StyleEncoder( + idim=odim, # the input is mel-spectrogram + gst_tokens=gst_tokens, + gst_token_dim=adim, + gst_heads=gst_heads, + conv_layers=gst_conv_layers, + conv_chans_list=gst_conv_chans_list, + conv_kernel_size=gst_conv_kernel_size, + conv_stride=gst_conv_stride, + gru_layers=gst_gru_layers, + gru_units=gst_gru_units, ) + + # define spk and lang embedding + self.spks = None + if spks is not None and spks > 1: + self.spks = spks + self.sid_emb = nn.Embedding(spks, adim) + self.langs = None + if langs is not None and langs > 1: + self.langs = langs + self.lid_emb = nn.Embedding(langs, adim) + + # define additional projection for speaker embedding + self.spk_embed_dim = None + if spk_embed_dim is not None and spk_embed_dim > 0: + self.spk_embed_dim = spk_embed_dim + self.spk_embed_integration_type = spk_embed_integration_type + if self.spk_embed_dim is not None: + if self.spk_embed_integration_type == "add": + self.projection = nn.Linear(self.spk_embed_dim, adim) + else: + self.projection = nn.Linear(adim + self.spk_embed_dim, adim) + + # define duration predictor + self.duration_predictor = DurationPredictor( + idim=adim, + n_layers=duration_predictor_layers, + n_chans=duration_predictor_chans, + kernel_size=duration_predictor_kernel_size, + dropout_rate=duration_predictor_dropout_rate, ) + + # define pitch predictor + self.pitch_predictor = VariancePredictor( + 
idim=adim, + n_layers=pitch_predictor_layers, + n_chans=pitch_predictor_chans, + kernel_size=pitch_predictor_kernel_size, + dropout_rate=pitch_predictor_dropout, ) + # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg + self.pitch_embed = nn.Sequential( + nn.Conv1D( + in_channels=1, + out_channels=adim, + kernel_size=pitch_embed_kernel_size, + padding=(pitch_embed_kernel_size - 1) // 2, ), + nn.Dropout(pitch_embed_dropout), ) + + # define energy predictor + self.energy_predictor = VariancePredictor( + idim=adim, + n_layers=energy_predictor_layers, + n_chans=energy_predictor_chans, + kernel_size=energy_predictor_kernel_size, + dropout_rate=energy_predictor_dropout, ) + # NOTE(kan-bayashi): We use continuous enegy + FastPitch style avg + self.energy_embed = nn.Sequential( + nn.Conv1D( + in_channels=1, + out_channels=adim, + kernel_size=energy_embed_kernel_size, + padding=(energy_embed_kernel_size - 1) // 2, ), + nn.Dropout(energy_embed_dropout), ) + + # define length regulator + self.length_regulator = GaussianUpsampling() + + # define decoder + # NOTE: we use encoder as decoder + # because fastspeech's decoder is the same as encoder + if decoder_type == "transformer": + self.decoder = TransformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + # in decoder, don't need layer before pos_enc_class (we use embedding here in encoder) + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + pos_enc_layer_type=transformer_pos_enc_layer_type, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + activation_type=conformer_activation_type, ) + + elif decoder_type == "conformer": + self.decoder = ConformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_dec_kernel_size, ) + else: + raise ValueError(f"{decoder_type} is not supported.") + + self.generator = HiFiGANGenerator( + in_channels=adim, + out_channels=generator_out_channels, + channels=generator_channels, + global_channels=generator_global_channels, + kernel_size=generator_kernel_size, + upsample_scales=generator_upsample_scales, + upsample_kernel_sizes=generator_upsample_kernel_sizes, + resblock_kernel_sizes=generator_resblock_kernel_sizes, + resblock_dilations=generator_resblock_dilations, + use_additional_convs=generator_use_additional_convs, + bias=generator_bias, + nonlinear_activation=generator_nonlinear_activation, + nonlinear_activation_params=generator_nonlinear_activation_params, + use_weight_norm=generator_use_weight_norm, ) + + self.alignment_module = 
AlignmentModule(adim, odim)
+
+        # initialize parameters
+        self._reset_parameters(
+            init_type=init_type,
+            init_enc_alpha=init_enc_alpha,
+            init_dec_alpha=init_dec_alpha, )
+
+    def forward(
+            self,
+            text: paddle.Tensor,
+            text_lengths: paddle.Tensor,
+            feats: paddle.Tensor,
+            feats_lengths: paddle.Tensor,
+            durations: paddle.Tensor,
+            durations_lengths: paddle.Tensor,
+            pitch: paddle.Tensor,
+            energy: paddle.Tensor,
+            sids: Optional[paddle.Tensor]=None,
+            spembs: Optional[paddle.Tensor]=None,
+            lids: Optional[paddle.Tensor]=None,
+            use_alignment_module: bool=False,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
+               paddle.Tensor, paddle.Tensor,
+               Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
+                     paddle.Tensor, paddle.Tensor, ], ]:
+        """Calculate forward propagation.
+        Args:
+            text (Tensor):
+                Text index tensor (B, T_text).
+            text_lengths (Tensor):
+                Text length tensor (B,).
+            feats (Tensor):
+                Feature tensor (B, T_feats, aux_channels).
+            feats_lengths (Tensor):
+                Feature length tensor (B,).
+            durations (Tensor):
+                Batch of padded token durations (B, T_text).
+            durations_lengths (Tensor):
+                Duration length tensor (B,).
+            pitch (Tensor):
+                Batch of padded token-averaged pitch (B, T_text, 1).
+            energy (Tensor):
+                Batch of padded token-averaged energy (B, T_text, 1).
+            sids (Optional[Tensor]):
+                Speaker index tensor (B,) or (B, 1).
+            spembs (Optional[Tensor]):
+                Speaker embedding tensor (B, spk_embed_dim).
+            lids (Optional[Tensor]):
+                Language index tensor (B,) or (B, 1).
+            use_alignment_module (bool):
+                Whether to use alignment module.
+
+        Returns:
+            Tensor:
+                Waveform tensor (B, 1, segment_size * upsample_factor).
+            Tensor:
+                binarization loss ()
+            Tensor:
+                log probability attention matrix (B, T_feats, T_text)
+            Tensor:
+                Segments start index tensor (B,).
+            Tensor:
+                predicted duration (B, T_text)
+            Tensor:
+                ground-truth duration obtained from an alignment module (B, T_text)
+            Tensor:
+                predicted pitch (B, T_text, 1)
+            Tensor:
+                ground-truth averaged pitch (B, T_text, 1)
+            Tensor:
+                predicted energy (B, T_text, 1)
+            Tensor:
+                ground-truth averaged energy (B, T_text, 1)
+        """
+        if use_alignment_module:
+            text = text[:, :text_lengths.max()]  # for data-parallel
+            feats = feats[:, :feats_lengths.max()]  # for data-parallel
+            pitch = pitch[:, :durations_lengths.max()]  # for data-parallel
+            energy = energy[:, :durations_lengths.max()]  # for data-parallel
+        else:
+            text = text[:, :text_lengths.max()]  # for data-parallel
+            feats = feats[:, :feats_lengths.max()]  # for data-parallel
+            pitch = pitch[:, :feats_lengths.max()]  # for data-parallel
+            energy = energy[:, :feats_lengths.max()]  # for data-parallel
+
+        # forward encoder
+        x_masks = self._source_mask(text_lengths)
+        hs, _ = self.encoder(text, x_masks)  # (B, T_text, adim)
+
+        # integrate with GST
+        if self.use_gst:
+            style_embs = self.gst(feats)
+            hs = hs + style_embs.unsqueeze(1)
+
+        # integrate with SID and LID embeddings
+        if self.spks is not None:
+            sid_embs = self.sid_emb(sids.reshape([-1]))
+            hs = hs + sid_embs.unsqueeze(1)
+        if self.langs is not None:
+            lid_embs = self.lid_emb(lids.reshape([-1]))
+            hs = hs + lid_embs.unsqueeze(1)
+
+        # integrate speaker embedding
+        if self.spk_embed_dim is not None:
+            hs = self._integrate_with_spk_embed(hs, spembs)
+
+        # forward alignment module and obtain duration, averaged pitch, energy
+        h_masks = make_pad_mask(text_lengths)
+        if use_alignment_module:
+            log_p_attn, attn = self.alignment_module(hs, feats, h_masks)
+            ds, bin_loss = viterbi_decode(log_p_attn, text_lengths,
+                                          feats_lengths)
+            ps = average_by_duration(ds,
+                                     pitch.squeeze(-1), text_lengths,
+                                     feats_lengths).unsqueeze(-1)
+            es = average_by_duration(ds,
+                                     energy.squeeze(-1), text_lengths,
+                                     feats_lengths).unsqueeze(-1)
+        else:
+            ds = durations
+            ps = pitch
+            es = energy
+            log_p_attn = attn = bin_loss = None
+
+        # forward duration predictor and variance predictors
+        if self.stop_gradient_from_pitch_predictor:
+            p_outs = self.pitch_predictor(hs.detach(), h_masks.unsqueeze(-1))
+        else:
+            p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1))
+        if self.stop_gradient_from_energy_predictor:
+            e_outs = self.energy_predictor(hs.detach(), h_masks.unsqueeze(-1))
+        else:
+            e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1))
+
+        d_outs = self.duration_predictor(hs, h_masks)
+
+        # use ground truth in training
+        p_embs = self.pitch_embed(ps.transpose([0, 2, 1])).transpose([0, 2, 1])
+        e_embs = self.energy_embed(es.transpose([0, 2, 1])).transpose([0, 2, 1])
+        hs = hs + e_embs + p_embs
+
+        # upsampling
+        h_masks = make_non_pad_mask(feats_lengths)
+        d_masks = make_non_pad_mask(text_lengths)
+        hs = self.length_regulator(hs, ds, h_masks,
+                                   d_masks)  # (B, T_feats, adim)
+
+        # forward decoder
+        h_masks = self._source_mask(feats_lengths)
+        zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)
+
+        # get random segments
+        z_segments, z_start_idxs = get_random_segments(
+            zs.transpose([0, 2, 1]),
+            feats_lengths,
+            self.segment_size, )
+        # forward generator
+        wav = self.generator(z_segments)
+        if use_alignment_module:
+            return wav, bin_loss, log_p_attn, z_start_idxs, d_outs, ds, p_outs, ps, e_outs, es
+        else:
+            return wav, None, None, z_start_idxs, d_outs, ds, p_outs, ps, e_outs, es
+
+    def inference(
+            self,
+            text: paddle.Tensor,
+            text_lengths: paddle.Tensor,
+            feats: Optional[paddle.Tensor]=None,
+            feats_lengths: Optional[paddle.Tensor]=None,
+            pitch: Optional[paddle.Tensor]=None,
+            energy: Optional[paddle.Tensor]=None,
+            sids: Optional[paddle.Tensor]=None,
+            spembs: Optional[paddle.Tensor]=None,
+            lids: Optional[paddle.Tensor]=None,
+            use_alignment_module: bool=False,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        """Run inference.
+
+        Args:
+            text (Tensor): Input text index tensor (B, T_text).
+            text_lengths (Tensor): Text length tensor (B,).
+            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+            feats_lengths (Tensor): Feature length tensor (B,).
+            pitch (Tensor): Pitch tensor (B, T_feats, 1).
+            energy (Tensor): Energy tensor (B, T_feats, 1).
+            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
+            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
+            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
+            use_alignment_module (bool): Whether to use alignment module.
+
+        Returns:
+            Tensor: Generated waveform tensor (B, T_wav).
+            Tensor: Duration tensor (B, T_text).
+ + """ + # forward encoder + x_masks = self._source_mask(text_lengths) + hs, _ = self.encoder(text, x_masks) # (B, T_text, adim) + + # integrate with GST + if self.use_gst: + style_embs = self.gst(ys) + hs = hs + style_embs.unsqueeze(1) + + # integrate with SID and LID embeddings + if self.spks is not None: + sid_embs = self.sid_emb(sids.view(-1)) + hs = hs + sid_embs.unsqueeze(1) + if self.langs is not None: + lid_embs = self.lid_emb(lids.view(-1)) + hs = hs + lid_embs.unsqueeze(1) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + + h_masks = make_pad_mask(text_lengths) + if use_alignment_module: + # forward alignment module and obtain duration, averaged pitch, energy + log_p_attn, attn = self.alignment_module(hs, feats, h_masks) + d_outs, _ = viterbi_decode(log_p_attn, text_lengths, feats_lengths) + p_outs = average_by_duration(d_outs, + pitch.squeeze(-1), text_lengths, + feats_lengths).unsqueeze(-1) + e_outs = average_by_duration(d_outs, + energy.squeeze(-1), text_lengths, + feats_lengths).unsqueeze(-1) + else: + # forward duration predictor and variance predictors + p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1)) + e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1)) + d_outs = self.duration_predictor.inference(hs, h_masks) + + p_embs = self.pitch_embed(p_outs.transpose([0, 2, 1])).transpose( + [0, 2, 1]) + e_embs = self.energy_embed(e_outs.transpose([0, 2, 1])).transpose( + [0, 2, 1]) + hs = hs + e_embs + p_embs + + # upsampling + if feats_lengths is not None: + h_masks = make_non_pad_mask(feats_lengths) + else: + h_masks = None + d_masks = make_non_pad_mask(text_lengths) + hs = self.length_regulator(hs, d_outs, h_masks, + d_masks) # (B, T_feats, adim) + + # forward decoder + if feats_lengths is not None: + h_masks = self._source_mask(feats_lengths) + else: + h_masks = None + zs, _ = self.decoder(hs, h_masks) # (B, T_feats, adim) + + # forward generator + wav = self.generator(zs.transpose([0, 2, 1])) + + return wav.squeeze(1), d_outs + + def _integrate_with_spk_embed(self, + hs: paddle.Tensor, + spembs: paddle.Tensor) -> paddle.Tensor: + """Integrate speaker embedding with hidden states. + + Args: + hs (Tensor): Batch of hidden state sequences (B, T_text, adim). + spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, T_text, adim). + + """ + if self.spk_embed_integration_type == "add": + # apply projection and then add to hidden states + spembs = self.projection(F.normalize(spembs)) + hs = hs + spembs.unsqueeze(1) + elif self.spk_embed_integration_type == "concat": + # concat hidden states with spk embeds and then apply projection + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.shape[1], + -1) + hs = self.projection(paddle.concat([hs, spembs], axis=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return hs + + def _generate_path(self, dur: paddle.Tensor, + mask: paddle.Tensor) -> paddle.Tensor: + """Generate path a.k.a. monotonic attention. + Args: + dur (Tensor): + Duration tensor (B, 1, T_text). + mask (Tensor): + Attention mask tensor (B, 1, T_feats, T_text). + Returns: + Tensor: + Path tensor (B, 1, T_feats, T_text). 
+ """ + b, _, t_y, t_x = paddle.shape(mask) + cum_dur = paddle.cumsum(dur, -1) + cum_dur_flat = paddle.reshape(cum_dur, [b * t_x]) + + path = paddle.arange(t_y, dtype=dur.dtype) + path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1) + path = paddle.reshape(path, [b, t_x, t_y]) + ''' + path will be like (t_x = 3, t_y = 5): + [[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.], + [1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.], + [1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]] + ''' + + path = paddle.cast(path, dtype='float32') + pad_tmp = self.pad1d(path)[:, :-1] + path = path - pad_tmp + return path.unsqueeze(1).transpose([0, 1, 3, 2]) * mask + + def _source_mask(self, ilens: paddle.Tensor) -> paddle.Tensor: + """Make masks for self-attention. + + Args: + ilens (LongTensor): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for self-attention. + dtype=paddle.uint8 + + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1], + [1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + x_masks = paddle.to_tensor(make_non_pad_mask(ilens)) + return x_masks.unsqueeze(-2) + + def _reset_parameters(self, + init_type: str, + init_enc_alpha: float, + init_dec_alpha: float): + # initialize parameters + initialize(self, init_type) + + # initialize alpha in scaled positional encoding + if self.encoder_type == "transformer" and self.use_scaled_pos_enc: + self.encoder.embed[-1].alpha.data = paddle.to_tensor(init_enc_alpha) + if self.decoder_type == "transformer" and self.use_scaled_pos_enc: + self.decoder.embed[-1].alpha.data = paddle.to_tensor(init_dec_alpha) diff --git a/paddlespeech/t2s/models/jets/jets.py b/paddlespeech/t2s/models/jets/jets.py new file mode 100644 index 00000000..4346c65b --- /dev/null +++ b/paddlespeech/t2s/models/jets/jets.py @@ -0,0 +1,582 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generator module in JETS. + +This code is based on https://github.com/imdanboy/jets. 
+ +""" +"""JETS module""" +import math +from typing import Any +from typing import Dict +from typing import Optional + +import paddle +from paddle import nn +from typeguard import check_argument_types + +from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator +from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator +from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator +from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator +from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator +from paddlespeech.t2s.models.jets.generator import JETSGenerator +from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out +from paddlespeech.utils.initialize import kaiming_uniform_ +from paddlespeech.utils.initialize import normal_ +from paddlespeech.utils.initialize import ones_ +from paddlespeech.utils.initialize import uniform_ +from paddlespeech.utils.initialize import zeros_ + +AVAILABLE_GENERATERS = { + "jets_generator": JETSGenerator, +} +AVAILABLE_DISCRIMINATORS = { + "hifigan_period_discriminator": + HiFiGANPeriodDiscriminator, + "hifigan_scale_discriminator": + HiFiGANScaleDiscriminator, + "hifigan_multi_period_discriminator": + HiFiGANMultiPeriodDiscriminator, + "hifigan_multi_scale_discriminator": + HiFiGANMultiScaleDiscriminator, + "hifigan_multi_scale_multi_period_discriminator": + HiFiGANMultiScaleMultiPeriodDiscriminator, +} + + +class JETS(nn.Layer): + """JETS module (generator + discriminator). + This is a module of JETS described in `JETS: Jointly Training FastSpeech2 + and HiFi-GAN for End to End Text to Speech`_. + .. _`JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech + Text-to-Speech`: https://arxiv.org/abs/2203.16852v1 + """ + + def __init__( + self, + # generator related + idim: int, + odim: int, + sampling_rate: int=22050, + generator_type: str="jets_generator", + generator_params: Dict[str, Any]={ + "adim": 256, + "aheads": 2, + "elayers": 4, + "eunits": 1024, + "dlayers": 4, + "dunits": 1024, + "positionwise_layer_type": "conv1d", + "positionwise_conv_kernel_size": 1, + "use_scaled_pos_enc": True, + "use_batch_norm": True, + "encoder_normalize_before": True, + "decoder_normalize_before": True, + "encoder_concat_after": False, + "decoder_concat_after": False, + "reduction_factor": 1, + "encoder_type": "transformer", + "decoder_type": "transformer", + "transformer_enc_dropout_rate": 0.1, + "transformer_enc_positional_dropout_rate": 0.1, + "transformer_enc_attn_dropout_rate": 0.1, + "transformer_dec_dropout_rate": 0.1, + "transformer_dec_positional_dropout_rate": 0.1, + "transformer_dec_attn_dropout_rate": 0.1, + "conformer_rel_pos_type": "latest", + "conformer_pos_enc_layer_type": "rel_pos", + "conformer_self_attn_layer_type": "rel_selfattn", + "conformer_activation_type": "swish", + "use_macaron_style_in_conformer": True, + "use_cnn_in_conformer": True, + "zero_triu": False, + "conformer_enc_kernel_size": 7, + "conformer_dec_kernel_size": 31, + "duration_predictor_layers": 2, + "duration_predictor_chans": 384, + "duration_predictor_kernel_size": 3, + "duration_predictor_dropout_rate": 0.1, + "energy_predictor_layers": 2, + "energy_predictor_chans": 384, + "energy_predictor_kernel_size": 3, + "energy_predictor_dropout": 0.5, + "energy_embed_kernel_size": 1, + "energy_embed_dropout": 0.5, + "stop_gradient_from_energy_predictor": False, + "pitch_predictor_layers": 5, + "pitch_predictor_chans": 384, + "pitch_predictor_kernel_size": 5, + 
"pitch_predictor_dropout": 0.5, + "pitch_embed_kernel_size": 1, + "pitch_embed_dropout": 0.5, + "stop_gradient_from_pitch_predictor": True, + "generator_out_channels": 1, + "generator_channels": 512, + "generator_global_channels": -1, + "generator_kernel_size": 7, + "generator_upsample_scales": [8, 8, 2, 2], + "generator_upsample_kernel_sizes": [16, 16, 4, 4], + "generator_resblock_kernel_sizes": [3, 7, 11], + "generator_resblock_dilations": + [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "generator_use_additional_convs": True, + "generator_bias": True, + "generator_nonlinear_activation": "LeakyReLU", + "generator_nonlinear_activation_params": { + "negative_slope": 0.1 + }, + "generator_use_weight_norm": True, + "segment_size": 64, + "spks": -1, + "langs": -1, + "spk_embed_dim": None, + "spk_embed_integration_type": "add", + "use_gst": False, + "gst_tokens": 10, + "gst_heads": 4, + "gst_conv_layers": 6, + "gst_conv_chans_list": [32, 32, 64, 64, 128, 128], + "gst_conv_kernel_size": 3, + "gst_conv_stride": 2, + "gst_gru_layers": 1, + "gst_gru_units": 128, + "init_type": "xavier_uniform", + "init_enc_alpha": 1.0, + "init_dec_alpha": 1.0, + "use_masking": False, + "use_weighted_masking": False, + }, + # discriminator related + discriminator_type: str="hifigan_multi_scale_multi_period_discriminator", + discriminator_params: Dict[str, Any]={ + "scales": 1, + "scale_downsample_pooling": "AvgPool1D", + "scale_downsample_pooling_params": { + "kernel_size": 4, + "stride": 2, + "padding": 2, + }, + "scale_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [15, 41, 5, 3], + "channels": 128, + "max_downsample_channels": 1024, + "max_groups": 16, + "bias": True, + "downsample_scales": [2, 2, 4, 4, 1], + "nonlinear_activation": "leakyrelu", + "nonlinear_activation_params": { + "negative_slope": 0.1 + }, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + "follow_official_norm": False, + "periods": [2, 3, 5, 7, 11], + "period_discriminator_params": { + "in_channels": 1, + "out_channels": 1, + "kernel_sizes": [5, 3], + "channels": 32, + "downsample_scales": [3, 3, 3, 3, 1], + "max_downsample_channels": 1024, + "bias": True, + "nonlinear_activation": "leakyrelu", + "nonlinear_activation_params": { + "negative_slope": 0.1 + }, + "use_weight_norm": True, + "use_spectral_norm": False, + }, + }, + cache_generator_outputs: bool=True, ): + """Initialize JETS module. + Args: + idim (int): + Input vocabrary size. + odim (int): + Acoustic feature dimension. The actual output channels will + be 1 since JETS is the end-to-end text-to-wave model but for the + compatibility odim is used to indicate the acoustic feature dimension. + sampling_rate (int): + Sampling rate, not used for the training but it will + be referred in saving waveform during the inference. + generator_type (str): + Generator type. + generator_params (Dict[str, Any]): + Parameter dict for generator. + discriminator_type (str): + Discriminator type. + discriminator_params (Dict[str, Any]): + Parameter dict for discriminator. + cache_generator_outputs (bool): + Whether to cache generator outputs. + """ + assert check_argument_types() + super().__init__() + + # define modules + generator_class = AVAILABLE_GENERATERS[generator_type] + if generator_type == "jets_generator": + # NOTE: Update parameters for the compatibility. + # The idim and odim is automatically decided from input data, + # where idim represents #vocabularies and odim represents + # the input acoustic feature dimension. 
+ generator_params.update(idim=idim, odim=odim) + self.generator = generator_class( + **generator_params, ) + discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] + self.discriminator = discriminator_class( + **discriminator_params, ) + + # cache + self.cache_generator_outputs = cache_generator_outputs + self._cache = None + + # store sampling rate for saving wav file + # (not used for the training) + self.fs = sampling_rate + + # store parameters for test compatibility + self.spks = self.generator.spks + self.langs = self.generator.langs + self.spk_embed_dim = self.generator.spk_embed_dim + + self.reuse_cache_gen = True + self.reuse_cache_dis = True + + self.reset_parameters() + self.generator._reset_parameters( + init_type=generator_params["init_type"], + init_enc_alpha=generator_params["init_enc_alpha"], + init_dec_alpha=generator_params["init_dec_alpha"], ) + + def forward( + self, + text: paddle.Tensor, + text_lengths: paddle.Tensor, + feats: paddle.Tensor, + feats_lengths: paddle.Tensor, + durations: paddle.Tensor, + durations_lengths: paddle.Tensor, + pitch: paddle.Tensor, + energy: paddle.Tensor, + sids: Optional[paddle.Tensor]=None, + spembs: Optional[paddle.Tensor]=None, + lids: Optional[paddle.Tensor]=None, + forward_generator: bool=True, + use_alignment_module: bool=False, + **kwargs, + ) -> Dict[str, Any]: + """Perform generator forward. + Args: + text (Tensor): + Text index tensor (B, T_text). + text_lengths (Tensor): + Text length tensor (B,). + feats (Tensor): + Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): + Feature length tensor (B,). + durations(Tensor(int64)): + Batch of padded durations (B, Tmax). + durations_lengths (Tensor): + durations length tensor (B,). + pitch(Tensor): + Batch of padded token-averaged pitch (B, Tmax, 1). + energy(Tensor): + Batch of padded token-averaged energy (B, Tmax, 1). + sids (Optional[Tensor]): + Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): + Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): + Language index tensor (B,) or (B, 1). + forward_generator (bool): + Whether to forward generator. + use_alignment_module (bool): + Whether to use alignment module. + Returns: + + """ + if forward_generator: + return self._forward_generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + durations=durations, + durations_lengths=durations_lengths, + pitch=pitch, + energy=energy, + sids=sids, + spembs=spembs, + lids=lids, + use_alignment_module=use_alignment_module, ) + else: + return self._forward_discrminator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + durations=durations, + durations_lengths=durations_lengths, + pitch=pitch, + energy=energy, + sids=sids, + spembs=spembs, + lids=lids, + use_alignment_module=use_alignment_module, ) + + def _forward_generator( + self, + text: paddle.Tensor, + text_lengths: paddle.Tensor, + feats: paddle.Tensor, + feats_lengths: paddle.Tensor, + durations: paddle.Tensor, + durations_lengths: paddle.Tensor, + pitch: paddle.Tensor, + energy: paddle.Tensor, + sids: Optional[paddle.Tensor]=None, + spembs: Optional[paddle.Tensor]=None, + lids: Optional[paddle.Tensor]=None, + use_alignment_module: bool=False, + **kwargs, ) -> Dict[str, Any]: + """Perform generator forward. + Args: + text (Tensor): + Text index tensor (B, T_text). + text_lengths (Tensor): + Text length tensor (B,). + feats (Tensor): + Feature tensor (B, T_feats, aux_channels). 
+ feats_lengths (Tensor): + Feature length tensor (B,). + durations(Tensor(int64)): + Batch of padded durations (B, Tmax). + durations_lengths (Tensor): + durations length tensor (B,). + pitch(Tensor): + Batch of padded token-averaged pitch (B, Tmax, 1). + energy(Tensor): + Batch of padded token-averaged energy (B, Tmax, 1). + sids (Optional[Tensor]): + Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): + Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): + Language index tensor (B,) or (B, 1). + use_alignment_module (bool): + Whether to use alignment module. + Returns: + + """ + # setup + # calculate generator outputs + self.reuse_cache_gen = True + if not self.cache_generator_outputs or self._cache is None: + self.reuse_cache_gen = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + durations=durations, + durations_lengths=durations_lengths, + pitch=pitch, + energy=energy, + sids=sids, + spembs=spembs, + lids=lids, + use_alignment_module=use_alignment_module, ) + else: + outs = self._cache + + # store cache + if self.training and self.cache_generator_outputs and not self.reuse_cache_gen: + self._cache = outs + + return outs + + def _forward_discrminator( + self, + text: paddle.Tensor, + text_lengths: paddle.Tensor, + feats: paddle.Tensor, + feats_lengths: paddle.Tensor, + durations: paddle.Tensor, + durations_lengths: paddle.Tensor, + pitch: paddle.Tensor, + energy: paddle.Tensor, + sids: Optional[paddle.Tensor]=None, + spembs: Optional[paddle.Tensor]=None, + lids: Optional[paddle.Tensor]=None, + use_alignment_module: bool=False, + **kwargs, ) -> Dict[str, Any]: + """Perform discriminator forward. + Args: + text (Tensor): + Text index tensor (B, T_text). + text_lengths (Tensor): + Text length tensor (B,). + feats (Tensor): + Feature tensor (B, T_feats, aux_channels). + feats_lengths (Tensor): + Feature length tensor (B,). + durations(Tensor(int64)): + Batch of padded durations (B, Tmax). + durations_lengths (Tensor): + durations length tensor (B,). + pitch(Tensor): + Batch of padded token-averaged pitch (B, Tmax, 1). + energy(Tensor): + Batch of padded token-averaged energy (B, Tmax, 1). + sids (Optional[Tensor]): + Speaker index tensor (B,) or (B, 1). + spembs (Optional[Tensor]): + Speaker embedding tensor (B, spk_embed_dim). + lids (Optional[Tensor]): + Language index tensor (B,) or (B, 1). + use_alignment_module (bool): + Whether to use alignment module. + Returns: + + """ + # setup + # calculate generator outputs + self.reuse_cache_dis = True + if not self.cache_generator_outputs or self._cache is None: + self.reuse_cache_dis = False + outs = self.generator( + text=text, + text_lengths=text_lengths, + feats=feats, + feats_lengths=feats_lengths, + durations=durations, + durations_lengths=durations_lengths, + pitch=pitch, + energy=energy, + sids=sids, + spembs=spembs, + lids=lids, + use_alignment_module=use_alignment_module, + **kwargs, ) + else: + outs = self._cache + + # store cache + if self.cache_generator_outputs and not self.reuse_cache_dis: + self._cache = outs + + return outs + + def inference(self, + text: paddle.Tensor, + feats: Optional[paddle.Tensor]=None, + pitch: Optional[paddle.Tensor]=None, + energy: Optional[paddle.Tensor]=None, + use_alignment_module: bool=False, + **kwargs) -> Dict[str, paddle.Tensor]: + """Run inference. + Args: + text (Tensor): + Input text index tensor (T_text,). + feats (Tensor): + Feature tensor (T_feats, aux_channels). 
+            pitch (Tensor):
+                Pitch tensor (T_feats, 1).
+            energy (Tensor):
+                Energy tensor (T_feats, 1).
+            use_alignment_module (bool):
+                Whether to use alignment module.
+        Returns:
+            Dict[str, Tensor]:
+                * wav (Tensor):
+                    Generated waveform tensor (T_wav,).
+                * duration (Tensor):
+                    Predicted duration tensor (T_text,).
+        """
+        # setup
+        text = text[None]
+        text_lengths = paddle.to_tensor(paddle.shape(text)[1])
+
+        # inference
+        if use_alignment_module:
+            assert feats is not None
+            feats = feats[None]
+            feats_lengths = paddle.to_tensor(paddle.shape(feats)[1])
+            pitch = pitch[None]
+            energy = energy[None]
+            wav, dur = self.generator.inference(
+                text=text,
+                text_lengths=text_lengths,
+                feats=feats,
+                feats_lengths=feats_lengths,
+                pitch=pitch,
+                energy=energy,
+                use_alignment_module=use_alignment_module,
+                **kwargs)
+        else:
+            wav, dur = self.generator.inference(
+                text=text,
+                text_lengths=text_lengths,
+                **kwargs, )
+        return dict(wav=paddle.reshape(wav, [-1]), duration=dur[0])
+
+    def reset_parameters(self):
+        def _reset_parameters(module):
+            if isinstance(
+                    module,
+                    (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D,
+                     nn.Conv2DTranspose)):
+                kaiming_uniform_(module.weight, a=math.sqrt(5))
+                if module.bias is not None:
+                    fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
+                    if fan_in != 0:
+                        bound = 1 / math.sqrt(fan_in)
+                        uniform_(module.bias, -bound, bound)
+
+            if isinstance(
+                    module,
+                    (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm,
+                     nn.LayerNorm)):
+                ones_(module.weight)
+                zeros_(module.bias)
+
+            if isinstance(module, nn.Linear):
+                kaiming_uniform_(module.weight, a=math.sqrt(5))
+                if module.bias is not None:
+                    fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
+                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                    uniform_(module.bias, -bound, bound)
+
+            if isinstance(module, nn.Embedding):
+                normal_(module.weight)
+                if module._padding_idx is not None:
+                    with paddle.no_grad():
+                        module.weight[module._padding_idx] = 0
+
+        self.apply(_reset_parameters)
+
+
+class JETSInference(nn.Layer):
+    def __init__(self, model):
+        super().__init__()
+        self.acoustic_model = model
+
+    def forward(self, text, sids=None):
+        out = self.acoustic_model.inference(text)
+        wav = out['wav']
+        return wav
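For orientation, a typical synthesis call on the `JETS` wrapper above looks like the sketch below (illustrative only: the checkpoint path and the vocabulary/mel sizes are placeholders that must match the training config, and `soundfile` is just one common way to write the result):

```python
# Sketch: loading a trained JETS checkpoint and synthesizing one utterance.
import paddle
import soundfile as sf

from paddlespeech.t2s.models.jets import JETS

# placeholders: use the idim/odim from your training setup
jets = JETS(idim=300, odim=80)
jets.set_state_dict(paddle.load("jets.pdparams"))  # hypothetical checkpoint
jets.eval()

phone_ids = paddle.to_tensor([12, 5, 42, 7, 19])  # (T_text,) unbatched input
with paddle.no_grad():
    out = jets.inference(phone_ids)  # batching is handled internally

# out["wav"] is (T_wav,), out["duration"] is (T_text,)
sf.write("out.wav", out["wav"].numpy(), samplerate=jets.fs)
```

diff --git a/paddlespeech/t2s/models/jets/jets_updater.py b/paddlespeech/t2s/models/jets/jets_updater.py
new file mode 100644
index 00000000..a82ac85c
--- /dev/null
+++ b/paddlespeech/t2s/models/jets/jets_updater.py
@@ -0,0 +1,437 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Training updater and evaluator modules for JETS.
+
+This code is based on https://github.com/imdanboy/jets.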
+ +""" +import logging +from typing import Dict + +import paddle +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.nn import Layer +from paddle.optimizer import Optimizer +from paddle.optimizer.lr import LRScheduler + +from paddlespeech.t2s.modules.nets_utils import get_segments +from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator +from paddlespeech.t2s.training.reporter import report +from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater +from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState + +logging.basicConfig( + format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', + datefmt='[%Y-%m-%d %H:%M:%S]') +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class JETSUpdater(StandardUpdater): + def __init__(self, + model: Layer, + optimizers: Dict[str, Optimizer], + criterions: Dict[str, Layer], + schedulers: Dict[str, LRScheduler], + dataloader: DataLoader, + generator_train_start_steps: int=0, + discriminator_train_start_steps: int=100000, + lambda_adv: float=1.0, + lambda_mel: float=45.0, + lambda_feat_match: float=2.0, + lambda_var: float=1.0, + lambda_align: float=2.0, + generator_first: bool=False, + use_alignment_module: bool=False, + output_dir=None): + # it is designed to hold multiple models + # 因为输入的是单模型,但是没有用到父类的 init(), 所以需要重新写这部分 + models = {"main": model} + self.models: Dict[str, Layer] = models + # self.model = model + + self.model = model._layers if isinstance(model, + paddle.DataParallel) else model + + self.optimizers = optimizers + self.optimizer_g: Optimizer = optimizers['generator'] + self.optimizer_d: Optimizer = optimizers['discriminator'] + + self.criterions = criterions + self.criterion_mel = criterions['mel'] + self.criterion_feat_match = criterions['feat_match'] + self.criterion_gen_adv = criterions["gen_adv"] + self.criterion_dis_adv = criterions["dis_adv"] + self.criterion_var = criterions["var"] + self.criterion_forwardsum = criterions["forwardsum"] + + self.schedulers = schedulers + self.scheduler_g = schedulers['generator'] + self.scheduler_d = schedulers['discriminator'] + + self.dataloader = dataloader + + self.generator_train_start_steps = generator_train_start_steps + self.discriminator_train_start_steps = discriminator_train_start_steps + + self.lambda_adv = lambda_adv + self.lambda_mel = lambda_mel + self.lambda_feat_match = lambda_feat_match + self.lambda_var = lambda_var + self.lambda_align = lambda_align + + self.use_alignment_module = use_alignment_module + + if generator_first: + self.turns = ["generator", "discriminator"] + else: + self.turns = ["discriminator", "generator"] + + self.state = UpdaterState(iteration=0, epoch=0) + self.train_iterator = iter(self.dataloader) + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def update_core(self, batch): + self.msg = "Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + + for turn in self.turns: + speech = batch["speech"] + speech = speech.unsqueeze(1) + text_lengths = batch["text_lengths"] + feats_lengths = batch["feats_lengths"] + outs = self.model( + text=batch["text"], + text_lengths=batch["text_lengths"], + feats=batch["feats"], + feats_lengths=batch["feats_lengths"], + durations=batch["durations"], + durations_lengths=batch["durations_lengths"], + pitch=batch["pitch"], + 
energy=batch["energy"], + sids=batch.get("spk_id", None), + spembs=batch.get("spk_emb", None), + forward_generator=turn == "generator", + use_alignment_module=self.use_alignment_module) + # Generator + if turn == "generator": + # parse outputs + speech_hat_, bin_loss, log_p_attn, start_idxs, d_outs, ds, p_outs, ps, e_outs, es = outs + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * + self.model.generator.upsample_factor, + segment_size=self.model.generator.segment_size * + self.model.generator.upsample_factor, ) + + # calculate discriminator outputs + p_hat = self.model.discriminator(speech_hat_) + with paddle.no_grad(): + # do not store discriminator gradient in generator turn + p = self.model.discriminator(speech_) + + # calculate losses + mel_loss = self.criterion_mel(speech_hat_, speech_) + + adv_loss = self.criterion_gen_adv(p_hat) + feat_match_loss = self.criterion_feat_match(p_hat, p) + dur_loss, pitch_loss, energy_loss = self.criterion_var( + d_outs, ds, p_outs, ps, e_outs, es, text_lengths) + + mel_loss = mel_loss * self.lambda_mel + adv_loss = adv_loss * self.lambda_adv + feat_match_loss = feat_match_loss * self.lambda_feat_match + g_loss = mel_loss + adv_loss + feat_match_loss + var_loss = ( + dur_loss + pitch_loss + energy_loss) * self.lambda_var + + gen_loss = g_loss + var_loss #+ align_loss + + report("train/generator_loss", float(gen_loss)) + report("train/generator_generator_loss", float(g_loss)) + report("train/generator_variance_loss", float(var_loss)) + report("train/generator_generator_mel_loss", float(mel_loss)) + report("train/generator_generator_adv_loss", float(adv_loss)) + report("train/generator_generator_feat_match_loss", + float(feat_match_loss)) + report("train/generator_variance_dur_loss", float(dur_loss)) + report("train/generator_variance_pitch_loss", float(pitch_loss)) + report("train/generator_variance_energy_loss", + float(energy_loss)) + + losses_dict["generator_loss"] = float(gen_loss) + losses_dict["generator_generator_loss"] = float(g_loss) + losses_dict["generator_variance_loss"] = float(var_loss) + losses_dict["generator_generator_mel_loss"] = float(mel_loss) + losses_dict["generator_generator_adv_loss"] = float(adv_loss) + losses_dict["generator_generator_feat_match_loss"] = float( + feat_match_loss) + losses_dict["generator_variance_dur_loss"] = float(dur_loss) + losses_dict["generator_variance_pitch_loss"] = float(pitch_loss) + losses_dict["generator_variance_energy_loss"] = float( + energy_loss) + + if self.use_alignment_module == True: + forwardsum_loss = self.criterion_forwardsum( + log_p_attn, text_lengths, feats_lengths) + align_loss = ( + forwardsum_loss + bin_loss) * self.lambda_align + report("train/generator_alignment_loss", float(align_loss)) + report("train/generator_alignment_forwardsum_loss", + float(forwardsum_loss)) + report("train/generator_alignment_bin_loss", + float(bin_loss)) + losses_dict["generator_alignment_loss"] = float(align_loss) + losses_dict["generator_alignment_forwardsum_loss"] = float( + forwardsum_loss) + losses_dict["generator_alignment_bin_loss"] = float( + bin_loss) + + self.optimizer_g.clear_grad() + gen_loss.backward() + + self.optimizer_g.step() + self.scheduler_g.step() + + # reset cache + if self.model.reuse_cache_gen or not self.model.training: + self.model._cache = None + + # Disctiminator + elif turn == "discriminator": + # parse outputs + speech_hat_, _, _, start_idxs, *_ = outs + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * + self.model.generator.upsample_factor, + 
+                    segment_size=self.model.generator.segment_size *
+                    self.model.generator.upsample_factor, )
+
+                # calculate discriminator outputs
+                p_hat = self.model.discriminator(speech_hat_.detach())
+                p = self.model.discriminator(speech_)
+
+                # calculate losses
+                real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
+                dis_loss = real_loss + fake_loss
+
+                report("train/real_loss", float(real_loss))
+                report("train/fake_loss", float(fake_loss))
+                report("train/discriminator_loss", float(dis_loss))
+                losses_dict["real_loss"] = float(real_loss)
+                losses_dict["fake_loss"] = float(fake_loss)
+                losses_dict["discriminator_loss"] = float(dis_loss)
+
+                self.optimizer_d.clear_grad()
+                dis_loss.backward()
+
+                self.optimizer_d.step()
+                self.scheduler_d.step()
+
+                # reset cache
+                if self.model.reuse_cache_dis or not self.model.training:
+                    self.model._cache = None
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
+
+class JETSEvaluator(StandardEvaluator):
+    def __init__(self,
+                 model,
+                 criterions: Dict[str, Layer],
+                 dataloader: DataLoader,
+                 lambda_adv: float=1.0,
+                 lambda_mel: float=45.0,
+                 lambda_feat_match: float=2.0,
+                 lambda_var: float=1.0,
+                 lambda_align: float=2.0,
+                 generator_first: bool=False,
+                 use_alignment_module: bool=False,
+                 output_dir=None):
+        # a single model is passed in and the parent's __init__() is not
+        # used, so this part has to be rewritten here
+        models = {"main": model}
+        self.models: Dict[str, Layer] = models
+        # self.model = model
+        self.model = model._layers if isinstance(model,
+                                                 paddle.DataParallel) else model
+
+        self.criterions = criterions
+        self.criterion_mel = criterions['mel']
+        self.criterion_feat_match = criterions['feat_match']
+        self.criterion_gen_adv = criterions["gen_adv"]
+        self.criterion_dis_adv = criterions["dis_adv"]
+        self.criterion_var = criterions["var"]
+        self.criterion_forwardsum = criterions["forwardsum"]
+
+        self.dataloader = dataloader
+
+        self.lambda_adv = lambda_adv
+        self.lambda_mel = lambda_mel
+        self.lambda_feat_match = lambda_feat_match
+        self.lambda_var = lambda_var
+        self.lambda_align = lambda_align
+        self.use_alignment_module = use_alignment_module
+
+        if generator_first:
+            self.turns = ["generator", "discriminator"]
+        else:
+            self.turns = ["discriminator", "generator"]
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def evaluate_core(self, batch):
+        # logging.debug("Evaluate: ")
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
+        for turn in self.turns:
+            speech = batch["speech"]
+            speech = speech.unsqueeze(1)
+            text_lengths = batch["text_lengths"]
+            feats_lengths = batch["feats_lengths"]
+            outs = self.model(
+                text=batch["text"],
+                text_lengths=batch["text_lengths"],
+                feats=batch["feats"],
+                feats_lengths=batch["feats_lengths"],
+                durations=batch["durations"],
+                durations_lengths=batch["durations_lengths"],
+                pitch=batch["pitch"],
+                energy=batch["energy"],
+                sids=batch.get("spk_id", None),
+                spembs=batch.get("spk_emb", None),
+                forward_generator=turn == "generator",
+                use_alignment_module=self.use_alignment_module)
+            # Generator
+            if turn == "generator":
+                # parse outputs
+                speech_hat_, bin_loss, log_p_attn, start_idxs, d_outs, ds, p_outs, ps, e_outs, es = outs
+                speech_ = get_segments(
+                    x=speech,
+                    start_idxs=start_idxs *
+                    self.model.generator.upsample_factor,
+                    segment_size=self.model.generator.segment_size *
+                    self.model.generator.upsample_factor, )
+
+                # calculate discriminator outputs
+                p_hat = self.model.discriminator(speech_hat_)
+                with paddle.no_grad():
+                    # do not store discriminator gradient in generator turn
+                    p = self.model.discriminator(speech_)
+
+                # calculate losses
+                mel_loss = self.criterion_mel(speech_hat_, speech_)
+
+                adv_loss = self.criterion_gen_adv(p_hat)
+                feat_match_loss = self.criterion_feat_match(p_hat, p)
+                dur_loss, pitch_loss, energy_loss = self.criterion_var(
+                    d_outs, ds, p_outs, ps, e_outs, es, text_lengths)
+
+                mel_loss = mel_loss * self.lambda_mel
+                adv_loss = adv_loss * self.lambda_adv
+                feat_match_loss = feat_match_loss * self.lambda_feat_match
+                g_loss = mel_loss + adv_loss + feat_match_loss
+                var_loss = (
+                    dur_loss + pitch_loss + energy_loss) * self.lambda_var
+
+                gen_loss = g_loss + var_loss  # + align_loss
+
+                report("eval/generator_loss", float(gen_loss))
+                report("eval/generator_generator_loss", float(g_loss))
+                report("eval/generator_variance_loss", float(var_loss))
+                report("eval/generator_generator_mel_loss", float(mel_loss))
+                report("eval/generator_generator_adv_loss", float(adv_loss))
+                report("eval/generator_generator_feat_match_loss",
+                       float(feat_match_loss))
+                report("eval/generator_variance_dur_loss", float(dur_loss))
+                report("eval/generator_variance_pitch_loss", float(pitch_loss))
+                report("eval/generator_variance_energy_loss",
+                       float(energy_loss))
+
+                losses_dict["generator_loss"] = float(gen_loss)
+                losses_dict["generator_generator_loss"] = float(g_loss)
+                losses_dict["generator_variance_loss"] = float(var_loss)
+                losses_dict["generator_generator_mel_loss"] = float(mel_loss)
+                losses_dict["generator_generator_adv_loss"] = float(adv_loss)
+                losses_dict["generator_generator_feat_match_loss"] = float(
+                    feat_match_loss)
+                losses_dict["generator_variance_dur_loss"] = float(dur_loss)
+                losses_dict["generator_variance_pitch_loss"] = float(pitch_loss)
+                losses_dict["generator_variance_energy_loss"] = float(
+                    energy_loss)
+
+                if self.use_alignment_module:
+                    forwardsum_loss = self.criterion_forwardsum(
+                        log_p_attn, text_lengths, feats_lengths)
+                    align_loss = (
+                        forwardsum_loss + bin_loss) * self.lambda_align
+                    report("eval/generator_alignment_loss", float(align_loss))
+                    report("eval/generator_alignment_forwardsum_loss",
+                           float(forwardsum_loss))
+                    report("eval/generator_alignment_bin_loss", float(bin_loss))
+                    losses_dict["generator_alignment_loss"] = float(align_loss)
+                    losses_dict["generator_alignment_forwardsum_loss"] = float(
+                        forwardsum_loss)
+                    losses_dict["generator_alignment_bin_loss"] = float(
+                        bin_loss)
+
+                # reset cache
+                if self.model.reuse_cache_gen or not self.model.training:
+                    self.model._cache = None
+
+            # Discriminator
+            elif turn == "discriminator":
+                # parse outputs
+                speech_hat_, _, _, start_idxs, *_ = outs
+                speech_ = get_segments(
+                    x=speech,
+                    start_idxs=start_idxs *
+                    self.model.generator.upsample_factor,
+                    segment_size=self.model.generator.segment_size *
+                    self.model.generator.upsample_factor, )
+
+                # calculate discriminator outputs
+                p_hat = self.model.discriminator(speech_hat_.detach())
+                p = self.model.discriminator(speech_)
+
+                # calculate losses
+                real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
+                dis_loss = real_loss + fake_loss
+
+                report("eval/real_loss", float(real_loss))
+                report("eval/fake_loss", float(fake_loss))
+                report("eval/discriminator_loss", float(dis_loss))
+                losses_dict["real_loss"] = float(real_loss)
+                losses_dict["fake_loss"] = float(fake_loss)
+                losses_dict["discriminator_loss"] = float(dis_loss)
+
+                # reset cache
+                if self.model.reuse_cache_dis or not
self.model.training: + self.model._cache = None + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + self.logger.info(self.msg) diff --git a/paddlespeech/t2s/models/jets/length_regulator.py b/paddlespeech/t2s/models/jets/length_regulator.py new file mode 100644 index 00000000..f7a395a6 --- /dev/null +++ b/paddlespeech/t2s/models/jets/length_regulator.py @@ -0,0 +1,67 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Generator module in JETS. + +This code is based on https://github.com/imdanboy/jets. + +""" +import paddle +import paddle.nn.functional as F +from paddle import nn + +from paddlespeech.t2s.modules.masked_fill import masked_fill + + +class GaussianUpsampling(nn.Layer): + """ + Gaussian upsampling with fixed temperature as in: + https://arxiv.org/abs/2010.04301 + """ + + def __init__(self, delta=0.1): + super().__init__() + self.delta = delta + + def forward(self, hs, ds, h_masks=None, d_masks=None): + """ + Args: + hs (Tensor): Batched hidden state to be expanded (B, T_text, adim) + ds (Tensor): Batched token duration (B, T_text) + h_masks (Tensor): Mask tensor (B,T_feats) + d_masks (Tensor): Mask tensor (B,T_text) + Returns: + Tensor: Expanded hidden state (B, T_feat, adim) + """ + B = ds.shape[0] + + if h_masks is None: + T_feats = paddle.to_tensor(ds.sum(), dtype="int32") + else: + T_feats = h_masks.shape[-1] + t = paddle.to_tensor( + paddle.arange(0, T_feats).unsqueeze(0).tile([B, 1]), + dtype="float32") + if h_masks is not None: + t = t * paddle.to_tensor(h_masks, dtype="float32") + + c = ds.cumsum(axis=-1) - ds / 2 + energy = -1 * self.delta * (t.unsqueeze(-1) - c.unsqueeze(1))**2 + if d_masks is not None: + d_masks = ~(d_masks.unsqueeze(1)) + d_masks.stop_gradient = True + d_masks = d_masks.tile([1, T_feats, 1]) + energy = masked_fill(energy, d_masks, -float("inf")) + p_attn = F.softmax(energy, axis=2) # (B, T_feats, T_text) + hs = paddle.matmul(p_attn, hs) + return hs diff --git a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py index 25197457..85b3453d 100644 --- a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py +++ b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py @@ -22,6 +22,7 @@ from .layers import ConvBlock from .layers import ConvNorm from .layers import LinearNorm from .layers import MFCC +from paddlespeech.t2s.modules.nets_utils import _reset_parameters from paddlespeech.utils.initialize import uniform_ @@ -59,6 +60,9 @@ class ASRCNN(nn.Layer): hidden_dim=hidden_dim // 2, n_token=n_token) + self.reset_parameters() + self.asr_s2s.reset_parameters() + def forward(self, x: paddle.Tensor, src_key_padding_mask: paddle.Tensor=None, @@ -108,6 +112,9 @@ class ASRCNN(nn.Layer): index_tensor.T + unmask_future_steps) return mask + def reset_parameters(self): + self.apply(_reset_parameters) + class ASRS2S(nn.Layer): def __init__(self, @@ -118,8 +125,7 @@ class ASRS2S(nn.Layer): 
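The new `GaussianUpsampling` length regulator above expands token-level hidden states to frame level, and the `get_segments` helper imported by the updater crops random waveform windows for the GAN turns. A minimal smoke test (shapes and values are illustrative; `get_segments`'s keyword signature is inferred from the calls in this patch):

```python
import paddle
from paddlespeech.t2s.models.jets.length_regulator import GaussianUpsampling
from paddlespeech.t2s.modules.nets_utils import get_segments

# token durations -> frame-level hidden states
# (B=1 so that ds.sum() equals the per-utterance total duration)
upsampler = GaussianUpsampling(delta=0.1)
hs = paddle.randn([1, 5, 8])                   # (B, T_text, adim)
ds = paddle.to_tensor([[2., 3., 1., 4., 2.]])  # (B, T_text) durations
print(upsampler(hs, ds).shape)                 # [1, 12, 8] = (B, sum(ds), adim)

# random fixed-size windows of a waveform batch, one start index per item
x = paddle.randn([2, 1, 100])                  # (B, C, T)
start_idxs = paddle.to_tensor([8, 40])
print(get_segments(x=x, start_idxs=start_idxs, segment_size=32).shape)  # [2, 1, 32]
```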
                  n_token: int=40):
         super().__init__()
         self.embedding = nn.Embedding(n_token, embedding_dim)
-        val_range = math.sqrt(6 / hidden_dim)
-        uniform_(self.embedding.weight, -val_range, val_range)
+        self.val_range = math.sqrt(6 / hidden_dim)
         self.decoder_rnn_dim = hidden_dim
         self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
@@ -236,3 +242,6 @@
         hidden = paddle.stack(hidden).transpose([1, 0, 2])
         return hidden, logit, alignments
+
+    def reset_parameters(self):
+        uniform_(self.embedding.weight, -self.val_range, self.val_range)
diff --git a/paddlespeech/t2s/models/starganv2_vc/losses.py b/paddlespeech/t2s/models/starganv2_vc/losses.py
index 8086a595..d94c9342 100644
--- a/paddlespeech/t2s/models/starganv2_vc/losses.py
+++ b/paddlespeech/t2s/models/starganv2_vc/losses.py
@@ -11,92 +11,102 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any
+from typing import Dict
+
 import paddle
 import paddle.nn.functional as F
-from munch import Munch
-from starganv2vc_paddle.transforms import build_transforms
+from .transforms import build_transforms
 
+# the loss-weight arguments are all handled in the updater now
-def compute_d_loss(nets,
-                   args,
-                   x_real,
-                   y_org,
-                   y_trg,
-                   z_trg=None,
-                   x_ref=None,
-                   use_r1_reg=True,
-                   use_adv_cls=False,
-                   use_con_reg=False):
-    args = Munch(args)
+
+
+def compute_d_loss(
+        nets: Dict[str, Any],
+        x_real: paddle.Tensor,
+        y_org: paddle.Tensor,
+        y_trg: paddle.Tensor,
+        z_trg: paddle.Tensor=None,
+        x_ref: paddle.Tensor=None,
+        # TODO: should be True here, but r1_reg has some bug now
+        use_r1_reg: bool=False,
+        use_adv_cls: bool=False,
+        use_con_reg: bool=False,
+        lambda_reg: float=1.,
+        lambda_adv_cls: float=0.1,
+        lambda_con_reg: float=10.):
 
     assert (z_trg is None) != (x_ref is None)
     # with real audios
     x_real.stop_gradient = False
-    out = nets.discriminator(x_real, y_org)
+    out = nets['discriminator'](x_real, y_org)
     loss_real = adv_loss(out, 1)
-
     # R1 regularization (https://arxiv.org/abs/1801.04406v4)
     if use_r1_reg:
         loss_reg = r1_reg(out, x_real)
     else:
-        loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
+        # loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
+        loss_reg = paddle.zeros([1])
 
     # consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
-    loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
+    loss_con_reg = paddle.zeros([1])
     if use_con_reg:
         t = build_transforms()
-        out_aug = nets.discriminator(t(x_real).detach(), y_org)
+        out_aug = nets['discriminator'](t(x_real).detach(), y_org)
         loss_con_reg += F.smooth_l1_loss(out, out_aug)
 
     # with fake audios
     with paddle.no_grad():
         if z_trg is not None:
-            s_trg = nets.mapping_network(z_trg, y_trg)
+            s_trg = nets['mapping_network'](z_trg, y_trg)
         else:  # x_ref is not None
-            s_trg = nets.style_encoder(x_ref, y_trg)
+            s_trg = nets['style_encoder'](x_ref, y_trg)
 
-        F0 = nets.f0_model.get_feature_GAN(x_real)
-    x_fake = nets.generator(x_real, s_trg, masks=None, F0=F0)
-    out = nets.discriminator(x_fake, y_trg)
+        F0 = nets['F0_model'].get_feature_GAN(x_real)
+    x_fake = nets['generator'](x_real, s_trg, masks=None, F0=F0)
+    out = nets['discriminator'](x_fake, y_trg)
     loss_fake = adv_loss(out, 0)
     if use_con_reg:
-        out_aug = nets.discriminator(t(x_fake).detach(), y_trg)
+        out_aug = nets['discriminator'](t(x_fake).detach(), y_trg)
         loss_con_reg += F.smooth_l1_loss(out, out_aug)
 
     # adversarial classifier loss
     if use_adv_cls:
-        out_de = nets.discriminator.classifier(x_fake)
+        out_de = nets['discriminator'].classifier(x_fake)
         loss_real_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
                                             y_org[y_org != y_trg])
         if use_con_reg:
-            out_de_aug = nets.discriminator.classifier(t(x_fake).detach())
+            out_de_aug = nets['discriminator'].classifier(t(x_fake).detach())
             loss_con_reg += F.smooth_l1_loss(out_de, out_de_aug)
     else:
         loss_real_adv_cls = paddle.zeros([1]).mean()
 
-    loss = loss_real + loss_fake + args.lambda_reg * loss_reg + \
-        args.lambda_adv_cls * loss_real_adv_cls + \
-        args.lambda_con_reg * loss_con_reg
+    loss = loss_real + loss_fake + lambda_reg * loss_reg + \
+        lambda_adv_cls * loss_real_adv_cls + \
+        lambda_con_reg * loss_con_reg
 
-    return loss, Munch(
-        real=loss_real.item(),
-        fake=loss_fake.item(),
-        reg=loss_reg.item(),
-        real_adv_cls=loss_real_adv_cls.item(),
-        con_reg=loss_con_reg.item())
+    return loss
 
 
-def compute_g_loss(nets,
-                   args,
-                   x_real,
-                   y_org,
-                   y_trg,
-                   z_trgs=None,
-                   x_refs=None,
-                   use_adv_cls=False):
-    args = Munch(args)
+def compute_g_loss(nets: Dict[str, Any],
+                   x_real: paddle.Tensor,
+                   y_org: paddle.Tensor,
+                   y_trg: paddle.Tensor,
+                   z_trgs: paddle.Tensor=None,
+                   x_refs: paddle.Tensor=None,
+                   use_adv_cls: bool=False,
+                   lambda_sty: float=1.,
+                   lambda_cyc: float=5.,
+                   lambda_ds: float=1.,
+                   lambda_norm: float=1.,
+                   lambda_asr: float=10.,
+                   lambda_f0: float=5.,
+                   lambda_f0_sty: float=0.1,
+                   lambda_adv: float=2.,
+                   lambda_adv_cls: float=0.5,
+                   norm_bias: float=0.5):
 
     assert (z_trgs is None) != (x_refs is None)
     if z_trgs is not None:
@@ -106,37 +116,37 @@ def compute_g_loss(nets,
 
     # compute style vectors
     if z_trgs is not None:
-        s_trg = nets.mapping_network(z_trg, y_trg)
+        s_trg = nets['mapping_network'](z_trg, y_trg)
     else:
-        s_trg = nets.style_encoder(x_ref, y_trg)
+        s_trg = nets['style_encoder'](x_ref, y_trg)
 
     # compute ASR/F0 features (real)
-    with paddle.no_grad():
-        F0_real, GAN_F0_real, cyc_F0_real = nets.f0_model(x_real)
-        ASR_real = nets.asr_model.get_feature(x_real)
+    # the reference code did not call .eval() and relied on no_grad();
+    # here the submodels are put in .eval() mode instead, and wrapping this
+    # block in paddle.no_grad() raises an error
+    F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
+    ASR_real = nets['asr_model'].get_feature(x_real)
 
     # adversarial loss
-    x_fake = nets.generator(x_real, s_trg, masks=None, F0=GAN_F0_real)
-    out = nets.discriminator(x_fake, y_trg)
+    x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)
+    out = nets['discriminator'](x_fake, y_trg)
     loss_adv = adv_loss(out, 1)
 
     # compute ASR/F0 features (fake)
-    F0_fake, GAN_F0_fake, _ = nets.f0_model(x_fake)
-    ASR_fake = nets.asr_model.get_feature(x_fake)
+    F0_fake, GAN_F0_fake, _ = nets['F0_model'](x_fake)
+    ASR_fake = nets['asr_model'].get_feature(x_fake)
 
     # norm consistency loss
     x_fake_norm = log_norm(x_fake)
     x_real_norm = log_norm(x_real)
-    loss_norm = ((
-        paddle.nn.ReLU()(paddle.abs(x_fake_norm - x_real_norm) - args.norm_bias)
-    )**2).mean()
+    tmp = paddle.abs(x_fake_norm - x_real_norm) - norm_bias
+    loss_norm = ((paddle.nn.ReLU()(tmp))**2).mean()
 
     # F0 loss
     loss_f0 = f0_loss(F0_fake, F0_real)
 
     # style F0 loss (style initialization)
-    if x_refs is not None and args.lambda_f0_sty > 0 and not use_adv_cls:
-        F0_sty, _, _ = nets.f0_model(x_ref)
+    if x_refs is not None and lambda_f0_sty > 0 and not use_adv_cls:
+        F0_sty, _, _ = nets['F0_model'](x_ref)
         loss_f0_sty = F.l1_loss(
             compute_mean_f0(F0_fake), compute_mean_f0(F0_sty))
     else:
@@ -146,61 +156,53 @@ def compute_g_loss(nets,
     loss_asr = F.smooth_l1_loss(ASR_fake, ASR_real)
 
     # style reconstruction loss
-    s_pred = nets.style_encoder(x_fake, y_trg)
+    s_pred = nets['style_encoder'](x_fake, y_trg)
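+    # style reconstruction: the style encoder should recover the target
+    # style vector s_trg from the converted audio x_fake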
loss_sty = paddle.mean(paddle.abs(s_pred - s_trg)) # diversity sensitive loss if z_trgs is not None: - s_trg2 = nets.mapping_network(z_trg2, y_trg) + s_trg2 = nets['mapping_network'](z_trg2, y_trg) else: - s_trg2 = nets.style_encoder(x_ref2, y_trg) - x_fake2 = nets.generator(x_real, s_trg2, masks=None, F0=GAN_F0_real) + s_trg2 = nets['style_encoder'](x_ref2, y_trg) + x_fake2 = nets['generator'](x_real, s_trg2, masks=None, F0=GAN_F0_real) x_fake2 = x_fake2.detach() - _, GAN_F0_fake2, _ = nets.f0_model(x_fake2) + _, GAN_F0_fake2, _ = nets['F0_model'](x_fake2) loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2)) loss_ds += F.smooth_l1_loss(GAN_F0_fake, GAN_F0_fake2.detach()) # cycle-consistency loss - s_org = nets.style_encoder(x_real, y_org) - x_rec = nets.generator(x_fake, s_org, masks=None, F0=GAN_F0_fake) + s_org = nets['style_encoder'](x_real, y_org) + x_rec = nets['generator'](x_fake, s_org, masks=None, F0=GAN_F0_fake) loss_cyc = paddle.mean(paddle.abs(x_rec - x_real)) # F0 loss in cycle-consistency loss - if args.lambda_f0 > 0: - _, _, cyc_F0_rec = nets.f0_model(x_rec) + if lambda_f0 > 0: + _, _, cyc_F0_rec = nets['F0_model'](x_rec) loss_cyc += F.smooth_l1_loss(cyc_F0_rec, cyc_F0_real) - if args.lambda_asr > 0: - ASR_recon = nets.asr_model.get_feature(x_rec) + if lambda_asr > 0: + ASR_recon = nets['asr_model'].get_feature(x_rec) loss_cyc += F.smooth_l1_loss(ASR_recon, ASR_real) # adversarial classifier loss if use_adv_cls: - out_de = nets.discriminator.classifier(x_fake) + out_de = nets['discriminator'].classifier(x_fake) loss_adv_cls = F.cross_entropy(out_de[y_org != y_trg], y_trg[y_org != y_trg]) else: loss_adv_cls = paddle.zeros([1]).mean() - loss = args.lambda_adv * loss_adv + args.lambda_sty * loss_sty \ - - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc\ - + args.lambda_norm * loss_norm \ - + args.lambda_asr * loss_asr \ - + args.lambda_f0 * loss_f0 \ - + args.lambda_f0_sty * loss_f0_sty \ - + args.lambda_adv_cls * loss_adv_cls - - return loss, Munch( - adv=loss_adv.item(), - sty=loss_sty.item(), - ds=loss_ds.item(), - cyc=loss_cyc.item(), - norm=loss_norm.item(), - asr=loss_asr.item(), - f0=loss_f0.item(), - adv_cls=loss_adv_cls.item()) + loss = lambda_adv * loss_adv + lambda_sty * loss_sty \ + - lambda_ds * loss_ds + lambda_cyc * loss_cyc \ + + lambda_norm * loss_norm \ + + lambda_asr * loss_asr \ + + lambda_f0 * loss_f0 \ + + lambda_f0_sty * loss_f0_sty \ + + lambda_adv_cls * loss_adv_cls + + return loss # for norm consistency loss -def log_norm(x, mean=-4, std=4, axis=2): +def log_norm(x: paddle.Tensor, mean: float=-4, std: float=4, axis: int=2): """ normalized log mel -> mel -> norm -> log(norm) """ @@ -209,7 +211,7 @@ def log_norm(x, mean=-4, std=4, axis=2): # for adversarial loss -def adv_loss(logits, target): +def adv_loss(logits: paddle.Tensor, target: float): assert target in [1, 0] if len(logits.shape) > 1: logits = logits.reshape([-1]) @@ -220,7 +222,7 @@ def adv_loss(logits, target): # for R1 regularization loss -def r1_reg(d_out, x_in): +def r1_reg(d_out: paddle.Tensor, x_in: paddle.Tensor): # zero-centered gradient penalty for real images batch_size = x_in.shape[0] grad_dout = paddle.grad( @@ -236,14 +238,14 @@ def r1_reg(d_out, x_in): # for F0 consistency loss -def compute_mean_f0(f0): +def compute_mean_f0(f0: paddle.Tensor): f0_mean = f0.mean(-1) f0_mean = f0_mean.expand((f0.shape[-1], f0_mean.shape[0])).transpose( (1, 0)) # (B, M) return f0_mean -def f0_loss(x_f0, y_f0): +def f0_loss(x_f0: paddle.Tensor, y_f0: paddle.Tensor): """ x.shape = (B, 1, M, L): 
predict y.shape = (B, 1, M, L): target diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py index 2a96b30c..99aeb73b 100644 --- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py +++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py @@ -25,6 +25,8 @@ import paddle import paddle.nn.functional as F from paddle import nn +from paddlespeech.t2s.modules.nets_utils import _reset_parameters + class DownSample(nn.Layer): def __init__(self, layer_type: str): @@ -355,6 +357,8 @@ class Generator(nn.Layer): if w_hpf > 0: self.hpf = HighPass(w_hpf) + self.reset_parameters() + def forward(self, x: paddle.Tensor, s: paddle.Tensor, @@ -399,6 +403,9 @@ class Generator(nn.Layer): out = self.to_out(x) return out + def reset_parameters(self): + self.apply(_reset_parameters) + class MappingNetwork(nn.Layer): def __init__(self, @@ -427,18 +434,19 @@ class MappingNetwork(nn.Layer): nn.ReLU(), nn.Linear(hidden_dim, style_dim)) ]) + self.reset_parameters() + def forward(self, z: paddle.Tensor, y: paddle.Tensor): """Calculate forward propagation. Args: z(Tensor(float32)): - Shape (B, 1, n_mels, T). + Shape (B, latent_dim). y(Tensor(float32)): speaker label. Shape (B, ). Returns: Tensor: Shape (style_dim, ) """ - h = self.shared(z) out = [] for layer in self.unshared: @@ -450,6 +458,9 @@ class MappingNetwork(nn.Layer): s = out[idx, y] return s + def reset_parameters(self): + self.apply(_reset_parameters) + class StyleEncoder(nn.Layer): def __init__(self, @@ -491,6 +502,8 @@ class StyleEncoder(nn.Layer): for _ in range(num_domains): self.unshared.append(nn.Linear(dim_out, style_dim)) + self.reset_parameters() + def forward(self, x: paddle.Tensor, y: paddle.Tensor): """Calculate forward propagation. Args: @@ -514,6 +527,9 @@ class StyleEncoder(nn.Layer): s = out[idx, y] return s + def reset_parameters(self): + self.apply(_reset_parameters) + class Discriminator(nn.Layer): def __init__(self, @@ -536,7 +552,19 @@ class Discriminator(nn.Layer): repeat_num=repeat_num) self.num_domains = num_domains + self.reset_parameters() + def forward(self, x: paddle.Tensor, y: paddle.Tensor): + """Calculate forward propagation. + Args: + x(Tensor(float32)): + Shape (B, 1, 80, T). + y(Tensor(float32)): + Shape (B, ). + Returns: + Tensor: + Shape (B, ) + """ out = self.dis(x, y) return out @@ -544,6 +572,9 @@ class Discriminator(nn.Layer): out = self.cls.get_feature(x) return out + def reset_parameters(self): + self.apply(_reset_parameters) + class Discriminator2D(nn.Layer): def __init__(self, diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py index 595add0a..1b811a3f 100644 --- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py +++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py @@ -11,3 +11,298 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
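The corrected `MappingNetwork` docstring above clarifies that the input is a latent vector `(B, latent_dim)`, not a mel spectrogram. A small sketch of calling it (constructor argument names and defaults are assumed from the upstream StarGANv2-VC recipe, not stated in this patch):

```python
import paddle
from paddlespeech.t2s.models.starganv2_vc.starganv2_vc import MappingNetwork

# hypothetical configuration: 20 speakers/domains, 16-dim latent, 48-dim style
mapper = MappingNetwork(num_domains=20, latent_dim=16, style_dim=48, hidden_dim=384)
z = paddle.randn([4, 16])            # (B, latent_dim), per the corrected docstring
y = paddle.to_tensor([0, 3, 7, 19])  # speaker/domain labels, shape (B,)
s = mapper(z, y)
print(s.shape)                       # expected [4, 48]: one style vector per item
```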
+import logging +from typing import Any +from typing import Dict + +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.nn import Layer +from paddle.optimizer import Optimizer +from paddle.optimizer.lr import LRScheduler + +from paddlespeech.t2s.models.starganv2_vc.losses import compute_d_loss +from paddlespeech.t2s.models.starganv2_vc.losses import compute_g_loss +from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator +from paddlespeech.t2s.training.reporter import report +from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater +from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState + +logging.basicConfig( + format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', + datefmt='[%Y-%m-%d %H:%M:%S]') +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class StarGANv2VCUpdater(StandardUpdater): + def __init__(self, + models: Dict[str, Layer], + optimizers: Dict[str, Optimizer], + schedulers: Dict[str, LRScheduler], + dataloader: DataLoader, + g_loss_params: Dict[str, Any]={ + 'lambda_sty': 1., + 'lambda_cyc': 5., + 'lambda_ds': 1., + 'lambda_norm': 1., + 'lambda_asr': 10., + 'lambda_f0': 5., + 'lambda_f0_sty': 0.1, + 'lambda_adv': 2., + 'lambda_adv_cls': 0.5, + 'norm_bias': 0.5, + }, + d_loss_params: Dict[str, Any]={ + 'lambda_reg': 1., + 'lambda_adv_cls': 0.1, + 'lambda_con_reg': 10., + }, + adv_cls_epoch: int=50, + con_reg_epoch: int=30, + use_r1_reg: bool=False, + output_dir=None): + self.models = models + + self.optimizers = optimizers + self.optimizer_g = optimizers['generator'] + self.optimizer_s = optimizers['style_encoder'] + self.optimizer_m = optimizers['mapping_network'] + self.optimizer_d = optimizers['discriminator'] + + self.schedulers = schedulers + self.scheduler_g = schedulers['generator'] + self.scheduler_s = schedulers['style_encoder'] + self.scheduler_m = schedulers['mapping_network'] + self.scheduler_d = schedulers['discriminator'] + + self.dataloader = dataloader + + self.g_loss_params = g_loss_params + self.d_loss_params = d_loss_params + + self.use_r1_reg = use_r1_reg + self.con_reg_epoch = con_reg_epoch + self.adv_cls_epoch = adv_cls_epoch + + self.state = UpdaterState(iteration=0, epoch=0) + self.train_iterator = iter(self.dataloader) + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def zero_grad(self): + self.optimizer_d.clear_grad() + self.optimizer_g.clear_grad() + self.optimizer_m.clear_grad() + self.optimizer_s.clear_grad() + + def scheduler(self): + self.scheduler_d.step() + self.scheduler_g.step() + self.scheduler_m.step() + self.scheduler_s.step() + + def update_core(self, batch): + self.msg = "Rank: {}, ".format(dist.get_rank()) + losses_dict = {} + # parse batch + x_real = batch['x_real'] + y_org = batch['y_org'] + x_ref = batch['x_ref'] + x_ref2 = batch['x_ref2'] + y_trg = batch['y_trg'] + z_trg = batch['z_trg'] + z_trg2 = batch['z_trg2'] + + use_con_reg = (self.state.epoch >= self.con_reg_epoch) + use_adv_cls = (self.state.epoch >= self.adv_cls_epoch) + + # Discriminator loss + # train the discriminator (by random reference) + self.zero_grad() + random_d_loss = compute_d_loss( + nets=self.models, + x_real=x_real, + y_org=y_org, + y_trg=y_trg, + z_trg=z_trg, + use_adv_cls=use_adv_cls, + use_con_reg=use_con_reg, + **self.d_loss_params) + random_d_loss.backward() + 
self.optimizer_d.step()
+        # train the discriminator (by target reference)
+        self.zero_grad()
+        target_d_loss = compute_d_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            x_ref=x_ref,
+            use_adv_cls=use_adv_cls,
+            use_con_reg=use_con_reg,
+            **self.d_loss_params)
+        target_d_loss.backward()
+        self.optimizer_d.step()
+        report("train/random_d_loss", float(random_d_loss))
+        report("train/target_d_loss", float(target_d_loss))
+        losses_dict["random_d_loss"] = float(random_d_loss)
+        losses_dict["target_d_loss"] = float(target_d_loss)
+
+        # Generator
+        # train the generator (by random reference)
+        self.zero_grad()
+        random_g_loss = compute_g_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            z_trgs=[z_trg, z_trg2],
+            use_adv_cls=use_adv_cls,
+            **self.g_loss_params)
+        random_g_loss.backward()
+        self.optimizer_g.step()
+        self.optimizer_m.step()
+        self.optimizer_s.step()
+
+        # train the generator (by target reference)
+        self.zero_grad()
+        target_g_loss = compute_g_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            x_refs=[x_ref, x_ref2],
+            use_adv_cls=use_adv_cls,
+            **self.g_loss_params)
+        target_g_loss.backward()
+        # should optimizer_g, optimizer_m and optimizer_s all be stepped here?
+        # the reference implementation omits the latter two; it is unclear
+        # whether that is an oversight
+        self.optimizer_g.step()
+        # self.optimizer_m.step()
+        # self.optimizer_s.step()
+        report("train/random_g_loss", float(random_g_loss))
+        report("train/target_g_loss", float(target_g_loss))
+        losses_dict["random_g_loss"] = float(random_g_loss)
+        losses_dict["target_g_loss"] = float(target_g_loss)
+
+        self.scheduler()
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
+
+class StarGANv2VCEvaluator(StandardEvaluator):
+    def __init__(self,
+                 models: Dict[str, Layer],
+                 dataloader: DataLoader,
+                 g_loss_params: Dict[str, Any]={
+                     'lambda_sty': 1.,
+                     'lambda_cyc': 5.,
+                     'lambda_ds': 1.,
+                     'lambda_norm': 1.,
+                     'lambda_asr': 10.,
+                     'lambda_f0': 5.,
+                     'lambda_f0_sty': 0.1,
+                     'lambda_adv': 2.,
+                     'lambda_adv_cls': 0.5,
+                     'norm_bias': 0.5,
+                 },
+                 d_loss_params: Dict[str, Any]={
+                     'lambda_reg': 1.,
+                     'lambda_adv_cls': 0.1,
+                     'lambda_con_reg': 10.,
+                 },
+                 adv_cls_epoch: int=50,
+                 con_reg_epoch: int=30,
+                 use_r1_reg: bool=False,
+                 output_dir=None):
+        self.models = models
+
+        self.dataloader = dataloader
+
+        self.g_loss_params = g_loss_params
+        self.d_loss_params = d_loss_params
+
+        self.use_r1_reg = use_r1_reg
+        self.con_reg_epoch = con_reg_epoch
+        self.adv_cls_epoch = adv_cls_epoch
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def evaluate_core(self, batch):
+        # logging.debug("Evaluate: ")
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
+        x_real = batch['x_real']
+        y_org = batch['y_org']
+        x_ref = batch['x_ref']
+        x_ref2 = batch['x_ref2']
+        y_trg = batch['y_trg']
+        z_trg = batch['z_trg']
+        z_trg2 = batch['z_trg2']
+
+        # NOTE: this snapshot left `use_adv_cls` undefined in evaluate_core;
+        # a conservative default is added here so the method can run (the
+        # updater derives it from the current epoch instead)
+        use_adv_cls = False
+
+        # eval the discriminator
+
+        random_d_loss = compute_d_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            z_trg=z_trg,
+            use_r1_reg=self.use_r1_reg,
+            use_adv_cls=use_adv_cls,
+            **self.d_loss_params)
+
+        target_d_loss = compute_d_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            x_ref=x_ref,
+            use_r1_reg=self.use_r1_reg,
+            use_adv_cls=use_adv_cls,
+            **self.d_loss_params)
+
+        report("eval/random_d_loss", float(random_d_loss))
+        report("eval/target_d_loss", float(target_d_loss))
+        losses_dict["random_d_loss"] = float(random_d_loss)
+        losses_dict["target_d_loss"] = float(target_d_loss)
+
+        # eval the generator
+
+        random_g_loss = compute_g_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            z_trgs=[z_trg, z_trg2],
+            use_adv_cls=use_adv_cls,
+            **self.g_loss_params)
+
+        target_g_loss = compute_g_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            x_refs=[x_ref, x_ref2],
+            use_adv_cls=use_adv_cls,
+            **self.g_loss_params)
+
+        report("eval/random_g_loss", float(random_g_loss))
+        report("eval/target_g_loss", float(target_g_loss))
+        losses_dict["random_g_loss"] = float(random_g_loss)
+        losses_dict["target_g_loss"] = float(target_g_loss)
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
diff --git a/paddlespeech/t2s/models/starganv2_vc/transforms.py b/paddlespeech/t2s/models/starganv2_vc/transforms.py
new file mode 100644
index 00000000..d7586147
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/transforms.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+
+## 1. RandomTimeStrech
+class TimeStrech(nn.Layer):
+    def __init__(self, scale):
+        super().__init__()
+        self.scale = scale
+
+    def forward(self, x: paddle.Tensor):
+        mel_size = x.shape[-1]
+
+        x = F.interpolate(
+            x,
+            scale_factor=(1, self.scale),
+            align_corners=False,
+            mode='bilinear').squeeze()
+
+        if x.shape[-1] < mel_size:
+            noise_length = (mel_size - x.shape[-1])
+            random_pos = random.randint(0, x.shape[-1]) - noise_length
+            if random_pos < 0:
+                random_pos = 0
+            noise = x[..., random_pos:random_pos + noise_length]
+            x = paddle.concat([x, noise], axis=-1)
+        else:
+            x = x[..., :mel_size]
+
+        return x.unsqueeze(1)
+
+
+## 2. PitchShift
+class PitchShift(nn.Layer):
+    def __init__(self, shift):
+        super().__init__()
+        self.shift = shift
+
+    def forward(self, x: paddle.Tensor):
+        if len(x.shape) == 2:
+            x = x.unsqueeze(0)
+        x = x.squeeze()
+        mel_size = x.shape[1]
+        shift_scale = (mel_size + self.shift) / mel_size
+        x = F.interpolate(
+            x.unsqueeze(1),
+            scale_factor=(shift_scale, 1.),
+            align_corners=False,
+            mode='bilinear').squeeze(1)
+
+        x = x[:, :mel_size]
+        if x.shape[1] < mel_size:
+            pad_size = mel_size - x.shape[1]
+            # note: fixed from `paddle.cat` (which does not exist) and from
+            # passing the shape as separate arguments to paddle.zeros
+            x = paddle.concat(
+                [x, paddle.zeros([x.shape[0], pad_size, x.shape[2]])], axis=1)
+        x = x.squeeze()
+        return x.unsqueeze(1)
+
+
+## 3. ShiftBias
+class ShiftBias(nn.Layer):
+    def __init__(self, bias):
+        super().__init__()
+        self.bias = bias
+
+    def forward(self, x: paddle.Tensor):
+        return x + self.bias
+
+
+## 4. Scaling
+class SpectScaling(nn.Layer):
+    def __init__(self, scale):
+        super().__init__()
+        self.scale = scale
+
+    def forward(self, x: paddle.Tensor):
+        return x * self.scale
+
+
+## 5. Time Flip
+class TimeFlip(nn.Layer):
+    def __init__(self, length):
+        super().__init__()
+        self.length = round(length)
+
+    def forward(self, x: paddle.Tensor):
+        if self.length > 1:
+            start = np.random.randint(0, x.shape[-1] - self.length)
+            x_ret = x.clone()
+            x_ret[..., start:start + self.length] = paddle.flip(
+                x[..., start:start + self.length], axis=[-1])
+            x = x_ret
+        return x
+
+
+class PhaseShuffle2D(nn.Layer):
+    def __init__(self, n: int=2):
+        super().__init__()
+        self.n = n
+        self.random = random.Random(1)
+
+    def forward(self, x: paddle.Tensor, move=None):
+        # x.size = (B, C, M, L)
+        if move is None:
+            move = self.random.randint(-self.n, self.n)
+
+        if move == 0:
+            return x
+        else:
+            left = x[:, :, :, :move]
+            right = x[:, :, :, move:]
+            shuffled = paddle.concat([right, left], axis=3)
+
+        return shuffled
+
+
+def build_transforms():
+    transforms = [
+        lambda M: TimeStrech(1 + (np.random.random() - 0.5) * M * 0.2),
+        lambda M: SpectScaling(1 + (np.random.random() - 1) * M * 0.1),
+        lambda M: PhaseShuffle2D(192),
+    ]
+    N, M = len(transforms), np.random.random()
+    composed = nn.Sequential(
+        * [trans(M) for trans in np.random.choice(transforms, N)])
+    return composed
diff --git a/paddlespeech/t2s/models/waveflow.py b/paddlespeech/t2s/models/waveflow.py
index 8e2ce822..b4818cab 100644
--- a/paddlespeech/t2s/models/waveflow.py
+++ b/paddlespeech/t2s/models/waveflow.py
@@ -236,7 +236,7 @@ class ResidualBlock(nn.Layer):
         Returns:
             res (Tensor):
-                A row of the the residual output. shape=(batch_size, channel, 1, width)
+                A row of the residual output. shape=(batch_size, channel, 1, width)
             skip (Tensor):
                 A row of the skip output. shape=(batch_size, channel, 1, width)
 
@@ -343,7 +343,7 @@ class ResidualNet(nn.LayerList):
         Returns:
             res (Tensor):
-                A row of the the residual output. shape=(batch_size, channel, 1, width)
+                A row of the residual output. shape=(batch_size, channel, 1, width)
             skip (Tensor):
                 A row of the skip output. shape=(batch_size, channel, 1, width)
 
@@ -465,7 +465,7 @@ class Flow(nn.Layer):
         self.resnet.start_sequence()
 
     def inverse(self, z, condition):
-        """Sampling from the the distrition p(X). It is done by sample form
+        """Sampling from the distribution p(X). It is done by sampling from
         p(Z) and transform the sample. It is a auto regressive transformation.
 
         Args:
@@ -600,7 +600,7 @@ class WaveFlow(nn.LayerList):
         return z, log_det_jacobian
 
     def inverse(self, z, condition):
-        """Sampling from the the distrition p(X).
+        """Sampling from the distribution p(X).
         It is done by sample a ``z`` form p(Z) and transform it into ``x``.
         Each Flow transform .. math:: `z_{i-1}` to .. math:: `z_{i}` in an
diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py
index 1a43f5ef..b4d78364 100644
--- a/paddlespeech/t2s/modules/losses.py
+++ b/paddlespeech/t2s/modules/losses.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
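The transform pipeline above is re-sampled on every `build_transforms()` call, so each discriminator consistency-regularization step sees a differently parameterized augmentation. A quick, self-contained smoke test (assuming the import path added by this patch; input uses the `(B, 1, n_mels, T)` mel layout fed to the discriminator):

```python
import paddle
from paddlespeech.t2s.models.starganv2_vc.transforms import build_transforms

aug = build_transforms()             # freshly sampled scales on each call
mel = paddle.randn([4, 1, 80, 192])  # (B, 1, n_mels, T)
out = aug(mel)
print(out.shape)                     # shape-preserving: [4, 1, 80, 192]
```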
import math +from typing import Tuple import librosa import numpy as np @@ -19,8 +20,13 @@ import paddle from paddle import nn from paddle.nn import functional as F from scipy import signal +from scipy.stats import betabinom +from typeguard import check_argument_types from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask +from paddlespeech.t2s.modules.predictor.duration_predictor import ( + DurationPredictorLoss, # noqa: H301 +) # Losses for WaveRNN @@ -1126,3 +1132,195 @@ class MLMLoss(nn.Layer): text_masked_pos_reshape) / paddle.sum((text_masked_pos) + 1e-10) return mlm_loss, text_mlm_loss + + +class VarianceLoss(nn.Layer): + def __init__(self, use_masking: bool=True, + use_weighted_masking: bool=False): + """Initialize JETS variance loss module. + Args: + use_masking (bool): Whether to apply masking for padded part in loss + calculation. + use_weighted_masking (bool): Whether to weighted masking in loss + calculation. + + """ + assert check_argument_types() + super().__init__() + + assert (use_masking != use_weighted_masking) or not use_masking + self.use_masking = use_masking + self.use_weighted_masking = use_weighted_masking + + # define criterions + reduction = "none" if self.use_weighted_masking else "mean" + self.mse_criterion = nn.MSELoss(reduction=reduction) + self.duration_criterion = DurationPredictorLoss(reduction=reduction) + + def forward( + self, + d_outs: paddle.Tensor, + ds: paddle.Tensor, + p_outs: paddle.Tensor, + ps: paddle.Tensor, + e_outs: paddle.Tensor, + es: paddle.Tensor, + ilens: paddle.Tensor, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Calculate forward propagation. + + Args: + d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text). + ds (LongTensor): Batch of durations (B, T_text). + p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1). + ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1). + e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1). + es (Tensor): Batch of target token-averaged energy (B, T_text, 1). + ilens (LongTensor): Batch of the lengths of each input (B,). + + Returns: + Tensor: Duration predictor loss value. + Tensor: Pitch predictor loss value. + Tensor: Energy predictor loss value. 
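+
+        Example (shapes only; an illustrative smoke test, not part of the
+        reference implementation)::
+
+            crit = VarianceLoss(use_masking=True)
+            dur_l, pitch_l, energy_l = crit(
+                d_outs, ds, p_outs, ps, e_outs, es, ilens)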
+ + """ + # apply mask to remove padded part + if self.use_masking: + duration_masks = paddle.to_tensor( + make_non_pad_mask(ilens), place=ds.place) + d_outs = d_outs.masked_select(duration_masks) + ds = ds.masked_select(duration_masks) + pitch_masks = paddle.to_tensor( + make_non_pad_mask(ilens).unsqueeze(-1), place=ds.place) + p_outs = p_outs.masked_select(pitch_masks) + e_outs = e_outs.masked_select(pitch_masks) + ps = ps.masked_select(pitch_masks) + es = es.masked_select(pitch_masks) + + # calculate loss + duration_loss = self.duration_criterion(d_outs, ds) + pitch_loss = self.mse_criterion(p_outs, ps) + energy_loss = self.mse_criterion(e_outs, es) + + # make weighted mask and apply it + if self.use_weighted_masking: + duration_masks = paddle.to_tensor( + make_non_pad_mask(ilens), place=ds.place) + duration_weights = (duration_masks.float() / + duration_masks.sum(dim=1, keepdim=True).float()) + duration_weights /= ds.size(0) + + # apply weight + duration_loss = (duration_loss.mul(duration_weights).masked_select( + duration_masks).sum()) + pitch_masks = duration_masks.unsqueeze(-1) + pitch_weights = duration_weights.unsqueeze(-1) + pitch_loss = pitch_loss.mul(pitch_weights).masked_select( + pitch_masks).sum() + energy_loss = ( + energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum()) + + return duration_loss, pitch_loss, energy_loss + + +class ForwardSumLoss(nn.Layer): + """ + https://openreview.net/forum?id=0NQwnnwAORi + """ + + def __init__(self, cache_prior: bool=True): + """ + Args: + cache_prior (bool): Whether to cache beta-binomial prior + """ + super().__init__() + self.cache_prior = cache_prior + self._cache = {} + + def forward( + self, + log_p_attn: paddle.Tensor, + ilens: paddle.Tensor, + olens: paddle.Tensor, + blank_prob: float=np.e**-1, ) -> paddle.Tensor: + """ + Args: + log_p_attn (Tensor): Batch of log probability of attention matrix (B, T_feats, T_text). + ilens (Tensor): Batch of the lengths of each input (B,). + olens (Tensor): Batch of the lengths of each target (B,). + blank_prob (float): Blank symbol probability + + Returns: + Tensor: forwardsum loss value. + """ + + B = log_p_attn.shape[0] + # add beta-binomial prior + bb_prior = self._generate_prior(ilens, olens) + bb_prior = paddle.to_tensor( + bb_prior, dtype=log_p_attn.dtype, place=log_p_attn.place) + log_p_attn = log_p_attn + bb_prior + + # a row must be added to the attention matrix to account for blank token of CTC loss + # (B,T_feats,T_text+1) + log_p_attn_pd = F.pad( + log_p_attn, (0, 0, 0, 0, 1, 0), value=np.log(blank_prob)) + loss = 0 + for bidx in range(B): + # construct target sequnece. + # Every text token is mapped to a unique sequnece number. + target_seq = paddle.arange( + 1, ilens[bidx] + 1, dtype="int32").unsqueeze(0) + cur_log_p_attn_pd = log_p_attn_pd[bidx, :olens[bidx], :ilens[ + bidx] + 1].unsqueeze(1) # (T_feats,1,T_text+1) + # The input of ctc_loss API need to be fixed + loss += F.ctc_loss( + log_probs=cur_log_p_attn_pd, + labels=target_seq, + input_lengths=olens[bidx:bidx + 1], + label_lengths=ilens[bidx:bidx + 1]) + loss = loss / B + + return loss + + def _generate_prior(self, text_lengths, feats_lengths, + w=1) -> paddle.Tensor: + """Generate alignment prior formulated as beta-binomial distribution + + Args: + text_lengths (Tensor): Batch of the lengths of each input (B,). + feats_lengths (Tensor): Batch of the lengths of each target (B,). 
+ w (float): Scaling factor; lower -> wider the width + + Returns: + Tensor: Batched 2d static prior matrix (B, T_feats, T_text) + """ + B = len(text_lengths) + T_text = text_lengths.max() + T_feats = feats_lengths.max() + + bb_prior = paddle.full((B, T_feats, T_text), fill_value=-np.inf) + for bidx in range(B): + T = feats_lengths[bidx].item() + N = text_lengths[bidx].item() + + key = str(T) + ',' + str(N) + if self.cache_prior and key in self._cache: + prob = self._cache[key] + else: + alpha = w * np.arange(1, T + 1, dtype=float) # (T,) + beta = w * np.array([T - t + 1 for t in alpha]) + k = np.arange(N) + batched_k = k[..., None] # (N,1) + prob = betabinom.pmf(batched_k, N, alpha, beta) # (N,T) + + # store cache + if self.cache_prior and key not in self._cache: + self._cache[key] = prob + + prob = paddle.to_tensor( + prob, place=text_lengths.place, dtype="float32").transpose( + (1, 0)) # -> (T,N) + bb_prior[bidx, :T, :N] = prob + + return bb_prior diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py index 99130acc..3d1b48de 100644 --- a/paddlespeech/t2s/modules/nets_utils.py +++ b/paddlespeech/t2s/modules/nets_utils.py @@ -20,6 +20,44 @@ import paddle from paddle import nn from typeguard import check_argument_types +from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out +from paddlespeech.utils.initialize import kaiming_uniform_ +from paddlespeech.utils.initialize import normal_ +from paddlespeech.utils.initialize import ones_ +from paddlespeech.utils.initialize import uniform_ +from paddlespeech.utils.initialize import zeros_ + + +# default init method of torch +# copy from https://github.com/PaddlePaddle/PaddleSpeech/blob/9cf8c1985a98bb380c183116123672976bdfe5c9/paddlespeech/t2s/models/vits/vits.py#L506 +def _reset_parameters(module): + if isinstance(module, (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, + nn.Conv2DTranspose)): + kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + uniform_(module.bias, -bound, bound) + + if isinstance(module, + (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)): + ones_(module.weight) + zeros_(module.bias) + + if isinstance(module, nn.Linear): + kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + uniform_(module.bias, -bound, bound) + + if isinstance(module, nn.Embedding): + normal_(module.weight) + if module._padding_idx is not None: + with paddle.no_grad(): + module.weight[module._padding_idx] = 0 + def pad_list(xs, pad_value): """Perform padding for the list of tensors. diff --git a/paddlespeech/t2s/modules/transformer/lightconv.py b/paddlespeech/t2s/modules/transformer/lightconv.py index 22217d50..85336f4f 100644 --- a/paddlespeech/t2s/modules/transformer/lightconv.py +++ b/paddlespeech/t2s/modules/transformer/lightconv.py @@ -110,7 +110,7 @@ class LightweightConvolution(nn.Layer): (batch, time1, time2) mask Return: - Tensor: ouput. (batch, time1, d_model) + Tensor: output. (batch, time1, d_model) """ # linear -> GLU -> lightconv -> linear diff --git a/paddlespeech/utils/argparse.py b/paddlespeech/utils/argparse.py new file mode 100644 index 00000000..aad3801e --- /dev/null +++ b/paddlespeech/utils/argparse.py @@ -0,0 +1,100 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import hashlib
+import os
+import sys
+from typing import Text
+
+# import the submodule explicitly so distutils.util.strtobool resolves below
+import distutils.util
+
+__all__ = ["print_arguments", "add_arguments", "get_commandline_args"]
+
+
+def get_commandline_args():
+    extra_chars = [
+        " ",
+        ";",
+        "&",
+        "(",
+        ")",
+        "|",
+        "^",
+        "<",
+        ">",
+        "?",
+        "*",
+        "[",
+        "]",
+        "$",
+        "`",
+        '"',
+        "\\",
+        "!",
+        "{",
+        "}",
+    ]
+
+    # Escape the extra characters for shell
+    argv = [
+        arg.replace("'", "'\\''") if all(char not in arg
+                                         for char in extra_chars) else
+        "'" + arg.replace("'", "'\\''") + "'" for arg in sys.argv
+    ]
+
+    return sys.executable + " " + " ".join(argv)
+
+
+def print_arguments(args, info=None):
+    """Print argparse's arguments.
+
+    Usage:
+
+    .. code-block:: python
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument("name", default="John", type=str, help="User name.")
+        args = parser.parse_args()
+        print_arguments(args)
+
+    :param args: Input argparse.Namespace for printing.
+    :type args: argparse.Namespace
+    """
+    filename = ""
+    if info:
+        filename = info["__file__"]
+    filename = os.path.basename(filename)
+    print(f"----------- {filename} Configuration Arguments -----------")
+    for arg, value in sorted(vars(args).items()):
+        print("%s: %s" % (arg, value))
+    print("-----------------------------------------------------------")
+
+
+def add_arguments(argname, type, default, help, argparser, **kwargs):
+    """Add argparse's argument.
+
+    Usage:
+
+    .. code-block:: python
+
+        parser = argparse.ArgumentParser()
+        add_arguments("name", str, "John", "User name.", parser)
+        args = parser.parse_args()
+    """
+    type = distutils.util.strtobool if type == bool else type
+    argparser.add_argument(
+        "--" + argname,
+        default=default,
+        type=type,
+        help=help + ' Default: %(default)s.',
+        **kwargs)
diff --git a/paddlespeech/vector/exps/ecapa_tdnn/train.py b/paddlespeech/vector/exps/ecapa_tdnn/train.py
index bf014045..2dc7a716 100644
--- a/paddlespeech/vector/exps/ecapa_tdnn/train.py
+++ b/paddlespeech/vector/exps/ecapa_tdnn/train.py
@@ -51,7 +51,7 @@ def main(args, config):
     # stage0: set the training device, cpu or gpu
     paddle.set_device(args.device)
-    # stage1: we must call the paddle.distributed.init_parallel_env() api at the begining
+    # stage1: we must call the paddle.distributed.init_parallel_env() api at the beginning
     paddle.distributed.init_parallel_env()
     nranks = paddle.distributed.get_world_size()
     rank = paddle.distributed.get_rank()
@@ -146,7 +146,7 @@ def main(args, config):
     timer.start()
 
     for epoch in range(start_epoch + 1, config.epochs + 1):
-        # at the begining, model must set to train mode
+        # at the beginning, model must set to train mode
         model.train()
 
         avg_loss = 0
diff --git a/paddlespeech/vector/exps/ge2e/preprocess.py b/paddlespeech/vector/exps/ge2e/preprocess.py
index dabe0ce7..ee59e624 100644
--- a/paddlespeech/vector/exps/ge2e/preprocess.py
+++ b/paddlespeech/vector/exps/ge2e/preprocess.py
@@ -42,7 +42,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--skip_existing",
         action="store_true",
-        help="Whether to skip ouput files with the same name. Useful if this script was interrupted."
+        help="Whether to skip output files with the same name. Useful if this script was interrupted."
     )
     parser.add_argument(
         "--no_trim",
diff --git a/speechx/.clang-format b/runtime/.clang-format
similarity index 100%
rename from speechx/.clang-format
rename to runtime/.clang-format
diff --git a/runtime/.gitignore b/runtime/.gitignore
new file mode 100644
index 00000000..a654dae4
--- /dev/null
+++ b/runtime/.gitignore
@@ -0,0 +1,7 @@
+engine/common/base/flags.h
+engine/common/base/log.h
+
+tools/valgrind*
+*log
+fc_patch/*
+test
diff --git a/runtime/CMakeLists.txt b/runtime/CMakeLists.txt
new file mode 100644
index 00000000..092c8b25
--- /dev/null
+++ b/runtime/CMakeLists.txt
@@ -0,0 +1,211 @@
+# >=3.17 support -DCMAKE_FIND_DEBUG_MODE=ON
+cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
+
+set(CMAKE_PROJECT_INCLUDE_BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/EnableCMP0077.cmake")
+
+set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
+
+include(system)
+
+project(paddlespeech VERSION 0.1)
+
+set(PPS_VERSION_MAJOR 1)
+set(PPS_VERSION_MINOR 0)
+set(PPS_VERSION_PATCH 0)
+set(PPS_VERSION "${PPS_VERSION_MAJOR}.${PPS_VERSION_MINOR}.${PPS_VERSION_PATCH}")
+
+# compiler option
+# Keep the same with openfst, -fPIC or -fpic
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl")
+SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
+SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")
+
+set(CMAKE_VERBOSE_MAKEFILE ON)
+set(CMAKE_FIND_DEBUG_MODE OFF)
+set(PPS_CXX_STANDARD 14)
+
+# set std-14
+set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
+
+# Ninja Generator will set CMAKE_BUILD_TYPE to Debug
+if(NOT CMAKE_BUILD_TYPE)
+    set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel" FORCE)
MinSizeRel" FORCE) +endif() + +# find_* e.g. find_library work when Cross-Compiling +if(ANDROID) + set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE BOTH) + set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH) +endif() + +if(BUILD_IN_MACOS) + add_definitions("-DOS_MACOSX") +endif() + +# install dir into `build/install` +set(CMAKE_INSTALL_PREFIX ${CMAKE_CURRENT_BINARY_DIR}/install) + +include(FetchContent) +include(ExternalProject) + +# fc_patch dir +set(FETCHCONTENT_QUIET off) +get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}") +set(FETCHCONTENT_BASE_DIR ${fc_patch}) + +############################################################################### +# Option Configurations +############################################################################### +# https://github.com/google/brotli/pull/655 +option(BUILD_SHARED_LIBS "Build shared libraries" ON) + +option(WITH_PPS_DEBUG "debug option" OFF) +if (WITH_PPS_DEBUG) + add_definitions("-DPPS_DEBUG") +endif() + +option(WITH_ASR "build asr" ON) +option(WITH_CLS "build cls" ON) +option(WITH_VAD "build vad" ON) + +option(WITH_GPU "NNet using GPU." OFF) + +option(WITH_PROFILING "enable c++ profling" OFF) +option(WITH_TESTING "unit test" ON) + +option(WITH_ONNX "u2 support onnx runtime" OFF) + +############################################################################### +# Include Third Party +############################################################################### +include(gflags) + +include(glog) + +include(pybind) + +#onnx +if(WITH_ONNX) + add_definitions(-DUSE_ONNX) +endif() + +# gtest +if(WITH_TESTING) + include(gtest) # download, build, install gtest +endif() + +# fastdeploy +include(fastdeploy) + +if(WITH_ASR) + # openfst + include(openfst) + add_dependencies(openfst gflags extern_glog) +endif() + +############################################################################### +# Find Package +############################################################################### +# https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207 +find_package(Threads REQUIRED) + +if(WITH_ASR) + # https://cmake.org/cmake/help/latest/module/FindPython3.html#module:FindPython3 + find_package(Python3 COMPONENTS Interpreter Development) + find_package(pybind11 CONFIG) + + if(Python3_FOUND) + message(STATUS "Python3_FOUND = ${Python3_FOUND}") + message(STATUS "Python3_EXECUTABLE = ${Python3_EXECUTABLE}") + message(STATUS "Python3_LIBRARIES = ${Python3_LIBRARIES}") + message(STATUS "Python3_INCLUDE_DIRS = ${Python3_INCLUDE_DIRS}") + message(STATUS "Python3_LINK_OPTIONS = ${Python3_LINK_OPTIONS}") + set(PYTHON_LIBRARIES ${Python3_LIBRARIES} CACHE STRING "python lib" FORCE) + set(PYTHON_INCLUDE_DIR ${Python3_INCLUDE_DIRS} CACHE STRING "python inc" FORCE) + endif() + + message(STATUS "PYTHON_LIBRARIES = ${PYTHON_LIBRARIES}") + message(STATUS "PYTHON_INCLUDE_DIR = ${PYTHON_INCLUDE_DIR}") + include_directories(${PYTHON_INCLUDE_DIR}) + + if(pybind11_FOUND) + message(STATUS "pybind11_INCLUDES = ${pybind11_INCLUDE_DIRS}") + message(STATUS "pybind11_LIBRARIES=${pybind11_LIBRARIES}") + message(STATUS "pybind11_DEFINITIONS=${pybind11_DEFINITIONS}") + endif() + + + # paddle libpaddle.so + # paddle include and link option + # -L/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/libs -L/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/paddle/fluid -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so + 
+    set(EXECUTE_COMMAND "import os"
+                        "import paddle"
+                        "include_dir = paddle.sysconfig.get_include()"
+                        "paddle_dir=os.path.split(include_dir)[0]"
+                        "libs_dir=os.path.join(paddle_dir, 'libs')"
+                        "fluid_dir=os.path.join(paddle_dir, 'fluid')"
+                        "out=' '.join([\"-L\" + libs_dir, \"-L\" + fluid_dir])"
+                        "out += \" -l:libpaddle.so -l:libdnnl.so.2 -l:libiomp5.so\"; print(out)"
+    )
+    execute_process(
+        COMMAND python -c "${EXECUTE_COMMAND}"
+        OUTPUT_VARIABLE PADDLE_LINK_FLAGS
+        RESULT_VARIABLE SUCCESS)
+
+    message(STATUS PADDLE_LINK_FLAGS= ${PADDLE_LINK_FLAGS})
+    string(STRIP ${PADDLE_LINK_FLAGS} PADDLE_LINK_FLAGS)
+
+    # paddle compile option
+    # -I/workspace/DeepSpeech-2.x/engine/venv/lib/python3.7/site-packages/paddle/include
+    set(EXECUTE_COMMAND "import paddle"
+                        "include_dir = paddle.sysconfig.get_include()"
+                        "print(f\"-I{include_dir}\")"
+    )
+    execute_process(
+        COMMAND python -c "${EXECUTE_COMMAND}"
+        OUTPUT_VARIABLE PADDLE_COMPILE_FLAGS)
+    message(STATUS PADDLE_COMPILE_FLAGS= ${PADDLE_COMPILE_FLAGS})
+    string(STRIP ${PADDLE_COMPILE_FLAGS} PADDLE_COMPILE_FLAGS)
+
+    # for LD_LIBRARY_PATH
+    # set(PADDLE_LIB_DIRS /workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/fluid:/workspace/DeepSpeech-2.x/tools/venv/lib/python3.7/site-packages/paddle/libs/)
+    set(EXECUTE_COMMAND "import os"
+                        "import paddle"
+                        "include_dir=paddle.sysconfig.get_include()"
+                        "paddle_dir=os.path.split(include_dir)[0]"
+                        "libs_dir=os.path.join(paddle_dir, 'libs')"
+                        "fluid_dir=os.path.join(paddle_dir, 'fluid')"
+                        "out=':'.join([libs_dir, fluid_dir]); print(out)"
+    )
+    execute_process(
+        COMMAND python -c "${EXECUTE_COMMAND}"
+        OUTPUT_VARIABLE PADDLE_LIB_DIRS)
+    message(STATUS PADDLE_LIB_DIRS= ${PADDLE_LIB_DIRS})
+endif()
+
+include(summary)
+
+###############################################################################
+# Add local library
+###############################################################################
+set(ENGINE_ROOT ${CMAKE_SOURCE_DIR}/engine)
+
+add_subdirectory(engine)
+
+
+###############################################################################
+# CPack library
+###############################################################################
+# build a CPack driven installer package
+include (InstallRequiredSystemLibraries)
+set(CPACK_PACKAGE_NAME "paddlespeech_library")
+set(CPACK_PACKAGE_VENDOR "paddlespeech")
+set(CPACK_PACKAGE_VERSION_MAJOR 1)
+set(CPACK_PACKAGE_VERSION_MINOR 0)
+set(CPACK_PACKAGE_VERSION_PATCH 0)
+set(CPACK_PACKAGE_DESCRIPTION "paddlespeech library")
+set(CPACK_PACKAGE_CONTACT "paddlespeech@baidu.com")
+set(CPACK_SOURCE_GENERATOR "TGZ")
+include (CPack)
diff --git a/speechx/README.md b/runtime/README.md
similarity index 92%
rename from speechx/README.md
rename to runtime/README.md
index 5d4b5845..40aa9444 100644
--- a/speechx/README.md
+++ b/runtime/README.md
@@ -1,4 +1,3 @@
-# SpeechX -- All in One Speech Task Inference
 
 ## Environment
 
@@ -9,7 +8,7 @@ We develop under:
 * gcc/g++/gfortran - 8.2.0
 * cmake - 3.16.0
 
-> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build speechx.
+> Please use `tools/env.sh` to create python `venv`, then `source venv/bin/activate` to build engine.
 
 > We make sure all things work fun under docker, and recommend using it to develop and deploy.
 
@@ -33,7 +32,7 @@ docker run --privileged --net=host --ipc=host -it --rm -v /path/to/paddlespeech
 bash tools/venv.sh
 ```
 
-2. Build `speechx` and `examples`.
+2. Build `engine` and `examples`.
 
 For now we are using feature under `develop` branch of paddle, so we need to install `paddlepaddle` nightly build version.
 For example:
@@ -113,3 +112,11 @@ apt-get install gfortran-8
 
 4. `Undefined reference to '_gfortran_concat_string'`
 
 using gcc 8.2, gfortran 8.2.
+
+5. `./boost/python/detail/wrap_python.hpp:57:11: fatal error: pyconfig.h: No such file or directory`
+
+```
+apt-get install python3-dev
+```
+
+For more info, please see [here](https://github.com/okfn/piati/issues/65).
diff --git a/runtime/build.sh b/runtime/build.sh
new file mode 100755
index 00000000..68889010
--- /dev/null
+++ b/runtime/build.sh
@@ -0,0 +1,33 @@
+#!/usr/bin/env bash
+set -xe
+
+BUILD_ROOT=build/Linux
+BUILD_DIR=${BUILD_ROOT}/x86_64
+
+mkdir -p ${BUILD_DIR}
+
+BUILD_TYPE=Release
+#BUILD_TYPE=Debug
+BUILD_SO=OFF
+BUILD_ONNX=ON
+BUILD_ASR=ON
+BUILD_CLS=ON
+BUILD_VAD=ON
+PPS_DEBUG=OFF
+FASTDEPLOY_INSTALL_DIR=""
+
+# The build script has been verified in the paddlepaddle docker image.
+# Please follow the instructions below to install the PaddlePaddle image.
+# https://www.paddlepaddle.org.cn/documentation/docs/zh/install/docker/linux-docker.html
+#cmake -B build -DBUILD_SHARED_LIBS=OFF -DWITH_ASR=OFF -DWITH_CLS=OFF -DWITH_VAD=ON -DFASTDEPLOY_INSTALL_DIR=/workspace/zhanghui/paddle/FastDeploy/build/Android/arm64-v8a-api-21/install
+cmake -B ${BUILD_DIR} \
+    -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
+    -DBUILD_SHARED_LIBS=${BUILD_SO} \
+    -DWITH_ONNX=${BUILD_ONNX} \
+    -DWITH_ASR=${BUILD_ASR} \
+    -DWITH_CLS=${BUILD_CLS} \
+    -DWITH_VAD=${BUILD_VAD} \
+    -DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
+    -DWITH_PPS_DEBUG=${PPS_DEBUG}
+
+cmake --build ${BUILD_DIR} -j
diff --git a/runtime/build_android.sh b/runtime/build_android.sh
new file mode 100755
index 00000000..ce78e67c
--- /dev/null
+++ b/runtime/build_android.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+set -ex
+
+ANDROID_NDK=/mnt/masimeng/workspace/software/android-ndk-r25b/
+
+# Setting up the Android toolchain
+ANDROID_ABI=arm64-v8a # 'arm64-v8a', 'armeabi-v7a'
+ANDROID_PLATFORM="android-21" # API >= 21
+ANDROID_STL=c++_shared # 'c++_shared', 'c++_static'
+ANDROID_TOOLCHAIN=clang # 'clang' only
+TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake
+
+# Create build directory
+BUILD_ROOT=build/Android
+BUILD_DIR=${BUILD_ROOT}/${ANDROID_ABI}-api-21
+FASTDEPLOY_INSTALL_DIR="/mnt/masimeng/workspace/FastDeploy/build/Android/arm64-v8a-api-21/install"
+
+mkdir -p ${BUILD_DIR}
+cd ${BUILD_DIR}
+
+# CMake configuration with Android toolchain
+cmake -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN_FILE} \
+      -DCMAKE_BUILD_TYPE=MinSizeRel \
+      -DANDROID_ABI=${ANDROID_ABI} \
+      -DANDROID_NDK=${ANDROID_NDK} \
+      -DANDROID_PLATFORM=${ANDROID_PLATFORM} \
+      -DANDROID_STL=${ANDROID_STL} \
+      -DANDROID_TOOLCHAIN=${ANDROID_TOOLCHAIN} \
+      -DBUILD_SHARED_LIBS=OFF \
+      -DWITH_ASR=OFF \
+      -DWITH_CLS=OFF \
+      -DWITH_VAD=ON \
+      -DFASTDEPLOY_INSTALL_DIR=${FASTDEPLOY_INSTALL_DIR} \
+      -DCMAKE_FIND_DEBUG_MODE=OFF \
+      -Wno-dev ../../..
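+
+# Note: ANDROID_NDK and FASTDEPLOY_INSTALL_DIR above are machine-specific
+# example paths; point them at your own NDK r25b install and FastDeploy
+# Android SDK build before running this script.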
+ +# Build FastDeploy Android C++ SDK +make diff --git a/runtime/build_ios.sh b/runtime/build_ios.sh new file mode 100644 index 00000000..74f76bf6 --- /dev/null +++ b/runtime/build_ios.sh @@ -0,0 +1,91 @@ +# https://www.jianshu.com/p/33672fb819f5 + +PATH="/Applications/CMake.app/Contents/bin":"$PATH" +tools_dir=$1 +ios_toolchain_cmake=${tools_dir}/"/ios-cmake-4.2.0/ios.toolchain.cmake" +fastdeploy_dir=${tools_dir}"/fastdeploy-ort-mac-build/" +build_targets=("OS64") +build_type_array=("Release") + +#static_name="libocr" +#lib_name="libocr" + +# Switch to workpath +current_path=`cd $(dirname $0);pwd` +work_path=${current_path}/ +build_path=${current_path}/build/ +output_path=${current_path}/output/ +cd ${work_path} + +# Clean +rm -rf ${build_path} +rm -rf ${output_path} + +if [ "$1"x = "clean"x ]; then + exit 0 +fi + +# Build Every Target +for target in "${build_targets[@]}" +do + for build_type in "${build_type_array[@]}" + do + echo -e "\033[1;36;40mBuilding ${build_type} ${target} ... \033[0m" + target_build_path=${build_path}/${target}/${build_type}/ + mkdir -p ${target_build_path} + + cd ${target_build_path} + if [ $? -ne 0 ];then + echo -e "\033[1;31;40mcd ${target_build_path} failed \033[0m" + exit -1 + fi + + if [ ${target} == "OS64" ];then + fastdeploy_install_dir=${fastdeploy_dir}/arm64 + else + fastdeploy_install_dir="" + echo "fastdeploy_install_dir is null" + exit -1 + fi + + cmake -DCMAKE_TOOLCHAIN_FILE=${ios_toolchain_cmake} \ + -DBUILD_IN_MACOS=ON \ + -DBUILD_SHARED_LIBS=OFF \ + -DWITH_ASR=OFF \ + -DWITH_CLS=OFF \ + -DWITH_VAD=ON \ + -DFASTDEPLOY_INSTALL_DIR=${fastdeploy_install_dir} \ + -DPLATFORM=${target} ../../../ + + cmake --build . --config ${build_type} + + mkdir output + cp engine/vad/interface/libpps_vad_interface.a output + cp engine/vad/interface/vad_interface_main.app/vad_interface_main output + cp ${fastdeploy_install_dir}/lib/libfastdeploy.dylib output + cp ${fastdeploy_install_dir}/third_libs/install/onnxruntime/lib/libonnxruntime.dylib output + + done +done + +## combine all ios libraries +#DEVROOT=/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/ +#LIPO_TOOL=${DEVROOT}/usr/bin/lipo +#LIBRARY_PATH=${build_path} +#LIBRARY_OUTPUT_PATH=${output_path}/IOS +#mkdir -p ${LIBRARY_OUTPUT_PATH} +# +#${LIPO_TOOL} \ +# -arch i386 ${LIBRARY_PATH}/ios_x86/Release/${lib_name}.a \ +# -arch x86_64 ${LIBRARY_PATH}/ios_x86_64/Release/${lib_name}.a \ +# -arch armv7 ${LIBRARY_PATH}/ios_armv7/Release/${lib_name}.a \ +# -arch armv7s ${LIBRARY_PATH}/ios_armv7s/Release/${lib_name}.a \ +# -arch arm64 ${LIBRARY_PATH}/ios_armv8/Release/${lib_name}.a \ +# -output ${LIBRARY_OUTPUT_PATH}/${lib_name}.a -create +# +#cp ${work_path}/lib/houyi/lib/ios/libhouyi_score.a ${LIBRARY_OUTPUT_PATH}/ +#cp ${work_path}/interface/ocr-interface.h ${output_path} +#cp ${work_path}/version/release.v ${output_path} +# +#echo -e "\033[1;36;40mBuild All Target Success At:\n${output_path}\033[0m" +#exit 0 diff --git a/speechx/cmake/EnableCMP0048.cmake b/runtime/cmake/EnableCMP0048.cmake similarity index 100% rename from speechx/cmake/EnableCMP0048.cmake rename to runtime/cmake/EnableCMP0048.cmake diff --git a/runtime/cmake/EnableCMP0077.cmake b/runtime/cmake/EnableCMP0077.cmake new file mode 100644 index 00000000..a7deaffb --- /dev/null +++ b/runtime/cmake/EnableCMP0077.cmake @@ -0,0 +1 @@ +cmake_policy(SET CMP0077 NEW) diff --git a/speechx/cmake/FindGFortranLibs.cmake b/runtime/cmake/FindGFortranLibs.cmake similarity index 100% rename from speechx/cmake/FindGFortranLibs.cmake rename 
to runtime/cmake/FindGFortranLibs.cmake
diff --git a/speechx/cmake/absl.cmake b/runtime/cmake/absl.cmake
similarity index 100%
rename from speechx/cmake/absl.cmake
rename to runtime/cmake/absl.cmake
diff --git a/speechx/cmake/boost.cmake b/runtime/cmake/boost.cmake
similarity index 100%
rename from speechx/cmake/boost.cmake
rename to runtime/cmake/boost.cmake
diff --git a/speechx/cmake/eigen.cmake b/runtime/cmake/eigen.cmake
similarity index 100%
rename from speechx/cmake/eigen.cmake
rename to runtime/cmake/eigen.cmake
diff --git a/runtime/cmake/fastdeploy.cmake b/runtime/cmake/fastdeploy.cmake
new file mode 100644
index 00000000..e095cd4c
--- /dev/null
+++ b/runtime/cmake/fastdeploy.cmake
@@ -0,0 +1,116 @@
+include(FetchContent)
+
+set(EXTERNAL_PROJECT_LOG_ARGS
+    LOG_DOWNLOAD 1    # Wrap download in script to log output
+    LOG_UPDATE 1      # Wrap update in script to log output
+    LOG_PATCH 1
+    LOG_CONFIGURE 1   # Wrap configure in script to log output
+    LOG_BUILD 1       # Wrap build in script to log output
+    LOG_INSTALL 1
+    LOG_TEST 1        # Wrap test in script to log output
+    LOG_MERGED_STDOUTERR 1
+    LOG_OUTPUT_ON_FAILURE 1
+)
+
+if(NOT FASTDEPLOY_INSTALL_DIR)
+    if(ANDROID)
+        FetchContent_Declare(
+            fastdeploy
+            URL https://bj.bcebos.com/fastdeploy/release/android/fastdeploy-android-1.0.4-shared.tgz
+            URL_HASH MD5=2a15301158e9eb157a4f11283689e7ba
+            ${EXTERNAL_PROJECT_LOG_ARGS}
+        )
+        add_definitions("-DUSE_PADDLE_LITE_BAKEND")
+        set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -g -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
+        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -g0 -O3 -mfloat-abi=softfp -mfpu=vfpv3 -mfpu=neon -fPIC -pie -fPIE")
+    else() # Linux
+        FetchContent_Declare(
+            fastdeploy
+            URL https://paddlespeech.bj.bcebos.com/speechx/fastdeploy/fastdeploy-1.0.5-x86_64-onnx.tar.gz
+            URL_HASH MD5=33900d986ea71aa78635e52f0733227c
+            ${EXTERNAL_PROJECT_LOG_ARGS}
+        )
+        set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -msse -msse2")
+        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -msse -msse2 -mavx -O3")
+    endif()
+
+    FetchContent_MakeAvailable(fastdeploy)
+
+    set(FASTDEPLOY_INSTALL_DIR ${fc_patch}/fastdeploy-src)
+endif()
+
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+# fix compiler flags conflict, since fastdeploy uses c++11 for the project
+# this line must come after `include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)`
+set(CMAKE_CXX_STANDARD ${PPS_CXX_STANDARD})
+
+include_directories(${FASTDEPLOY_INCS})
+
+# install fastdeploy and dependent libs
+# install_fastdeploy_libraries(${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR})
+# No dynamic libs need to be installed while using the
+# FastDeploy static lib.
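+# A minimal consumption sketch (the install path is illustrative, not a
+# shipped default):
+#   cmake -B build -DFASTDEPLOY_INSTALL_DIR=/opt/fastdeploy-1.0.5/install ..
+# Leaving FASTDEPLOY_INSTALL_DIR empty makes the FetchContent branch above
+# download a prebuilt SDK into ${fc_patch}/fastdeploy-src instead.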
+if(ANDROID AND WITH_ANDROID_STATIC_LIB)
+    return()
+endif()
+
+set(DYN_LIB_SUFFIX "*.so*")
+if(WIN32)
+    set(DYN_LIB_SUFFIX "*.dll")
+elseif(APPLE)
+    set(DYN_LIB_SUFFIX "*.dylib*")
+endif()
+
+if(FastDeploy_DIR)
+    set(DYN_SEARCH_DIR ${FastDeploy_DIR})
+elseif(FASTDEPLOY_INSTALL_DIR)
+    set(DYN_SEARCH_DIR ${FASTDEPLOY_INSTALL_DIR})
+else()
+    message(FATAL_ERROR "Please set FastDeploy_DIR/FASTDEPLOY_INSTALL_DIR before calling install_fastdeploy_libraries.")
+endif()
+
+file(GLOB_RECURSE ALL_NEED_DYN_LIBS ${DYN_SEARCH_DIR}/lib/${DYN_LIB_SUFFIX})
+file(GLOB_RECURSE ALL_DEPS_DYN_LIBS ${DYN_SEARCH_DIR}/third_libs/${DYN_LIB_SUFFIX})
+
+if(ENABLE_VISION)
+    # OpenCV
+    if(ANDROID)
+        file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${DYN_LIB_SUFFIX})
+    else()
+        file(GLOB_RECURSE ALL_OPENCV_DYN_LIBS ${OpenCV_DIR}/../../${DYN_LIB_SUFFIX})
+    endif()
+
+    list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_OPENCV_DYN_LIBS})
+
+    if(WIN32)
+        file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/x64/vc15/bin/${DYN_LIB_SUFFIX})
+        install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
+    elseif(ANDROID AND (NOT WITH_ANDROID_OPENCV_STATIC))
+        file(GLOB OPENCV_DYN_LIBS ${OpenCV_NATIVE_DIR}/libs/${ANDROID_ABI}/${DYN_LIB_SUFFIX})
+        install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
+    else() # linux/mac
+        file(GLOB OPENCV_DYN_LIBS ${OpenCV_DIR}/lib/${DYN_LIB_SUFFIX})
+        install(FILES ${OPENCV_DYN_LIBS} DESTINATION lib)
+    endif()
+
+    # FlyCV
+    if(ENABLE_FLYCV)
+        file(GLOB_RECURSE ALL_FLYCV_DYN_LIBS ${FLYCV_LIB_DIR}/${DYN_LIB_SUFFIX})
+        list(REMOVE_ITEM ALL_DEPS_DYN_LIBS ${ALL_FLYCV_DYN_LIBS})
+        if(ANDROID AND (NOT WITH_ANDROID_FLYCV_STATIC))
+            install(FILES ${ALL_FLYCV_DYN_LIBS} DESTINATION lib)
+        endif()
+    endif()
+endif()
+
+if(ENABLE_OPENVINO_BACKEND)
+    # need plugins.xml for openvino backend
+    set(OPENVINO_RUNTIME_BIN_DIR ${OPENVINO_DIR}/bin)
+    file(GLOB OPENVINO_PLUGIN_XML ${OPENVINO_RUNTIME_BIN_DIR}/*.xml)
+    install(FILES ${OPENVINO_PLUGIN_XML} DESTINATION lib)
+endif()
+
+# Install other libraries
+install(FILES ${ALL_NEED_DYN_LIBS} DESTINATION lib)
+install(FILES ${ALL_DEPS_DYN_LIBS} DESTINATION lib)
diff --git a/runtime/cmake/gflags.cmake b/runtime/cmake/gflags.cmake
new file mode 100644
index 00000000..aa0248ba
--- /dev/null
+++ b/runtime/cmake/gflags.cmake
@@ -0,0 +1,14 @@
+include(FetchContent)
+
+FetchContent_Declare(
+    gflags
+    URL https://paddleaudio.bj.bcebos.com/build/gflag-2.2.2.zip
+    URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
+)
+FetchContent_MakeAvailable(gflags)
+
+# needed by openfst
+include_directories(${gflags_BINARY_DIR}/include)
+link_directories(${gflags_BINARY_DIR})
+
+#install(FILES ${gflags_BINARY_DIR}/libgflags_nothreads.a DESTINATION lib)
diff --git a/runtime/cmake/glog.cmake b/runtime/cmake/glog.cmake
new file mode 100644
index 00000000..6c38963a
--- /dev/null
+++ b/runtime/cmake/glog.cmake
@@ -0,0 +1,35 @@
+include(FetchContent)
+
+if(ANDROID)
+else() # UNIX
+    add_definitions(-DWITH_GLOG)
+    FetchContent_Declare(
+        glog
+        URL https://paddleaudio.bj.bcebos.com/build/glog-0.4.0.zip
+        URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
+        CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+                   -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+                   -DCMAKE_CXX_FLAGS=${GLOG_CMAKE_CXX_FLAGS}
+                   -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
+                   -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
+                   -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
+                   -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
+                   -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
+                   -DCMAKE_POSITION_INDEPENDENT_CODE=ON
+ -DWITH_GFLAGS=OFF + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} + ${EXTERNAL_OPTIONAL_ARGS} + ) + set(BUILD_TESTING OFF) + FetchContent_MakeAvailable(glog) + include_directories(${glog_BINARY_DIR} ${glog_SOURCE_DIR}/src) +endif() + + +if(ANDROID) + add_library(extern_glog INTERFACE) + add_dependencies(extern_glog gflags) +else() # UNIX + add_library(extern_glog ALIAS glog) + add_dependencies(glog gflags) +endif() \ No newline at end of file diff --git a/runtime/cmake/gtest.cmake b/runtime/cmake/gtest.cmake new file mode 100644 index 00000000..a311721f --- /dev/null +++ b/runtime/cmake/gtest.cmake @@ -0,0 +1,27 @@ + +include(FetchContent) + +if(ANDROID) +else() # UNIX + FetchContent_Declare( + gtest + URL https://paddleaudio.bj.bcebos.com/build/gtest-release-1.11.0.zip + URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a + ) + FetchContent_MakeAvailable(gtest) + + include_directories(${gtest_BINARY_DIR} ${gtest_SOURCE_DIR}/src) +endif() + + + +if(ANDROID) + add_library(extern_gtest INTERFACE) +else() # UNIX + add_dependencies(gtest gflags extern_glog) + add_library(extern_gtest ALIAS gtest) +endif() + +if(WITH_TESTING) + enable_testing() +endif() diff --git a/speechx/cmake/kenlm.cmake b/runtime/cmake/kenlm.cmake similarity index 100% rename from speechx/cmake/kenlm.cmake rename to runtime/cmake/kenlm.cmake diff --git a/speechx/cmake/libsndfile.cmake b/runtime/cmake/libsndfile.cmake similarity index 100% rename from speechx/cmake/libsndfile.cmake rename to runtime/cmake/libsndfile.cmake diff --git a/speechx/cmake/openblas.cmake b/runtime/cmake/openblas.cmake similarity index 100% rename from speechx/cmake/openblas.cmake rename to runtime/cmake/openblas.cmake diff --git a/speechx/cmake/openfst.cmake b/runtime/cmake/openfst.cmake similarity index 69% rename from speechx/cmake/openfst.cmake rename to runtime/cmake/openfst.cmake index 07c33a74..42299c88 100644 --- a/speechx/cmake/openfst.cmake +++ b/runtime/cmake/openfst.cmake @@ -1,8 +1,8 @@ -include(FetchContent) set(openfst_PREFIX_DIR ${fc_patch}/openfst) set(openfst_SOURCE_DIR ${fc_patch}/openfst-src) set(openfst_BINARY_DIR ${fc_patch}/openfst-build) +include(FetchContent) # openfst Acknowledgments: #Cyril Allauzen, Michael Riley, Johan Schalkwyk, Wojciech Skut and Mehryar Mohri, #"OpenFst: A General and Efficient Weighted Finite-State Transducer Library", @@ -10,18 +10,33 @@ set(openfst_BINARY_DIR ${fc_patch}/openfst-build) #Application of Automata, (CIAA 2007), volume 4783 of Lecture Notes in #Computer Science, pages 11-23. Springer, 2007. http://www.openfst.org. 
+set(EXTERNAL_PROJECT_LOG_ARGS + LOG_DOWNLOAD 1 # Wrap download in script to log output + LOG_UPDATE 1 # Wrap update in script to log output + LOG_CONFIGURE 1# Wrap configure in script to log output + LOG_BUILD 1 # Wrap build in script to log output + LOG_TEST 1 # Wrap test in script to log output + LOG_INSTALL 1 # Wrap install in script to log output +) + ExternalProject_Add(openfst URL https://paddleaudio.bj.bcebos.com/build/openfst_1.7.2.zip URL_HASH SHA256=ffc56931025579a8af3515741c0f3b0fc3a854c023421472c07ca0c6389c75e6 + ${EXTERNAL_PROJECT_LOG_ARGS} PREFIX ${openfst_PREFIX_DIR} SOURCE_DIR ${openfst_SOURCE_DIR} BINARY_DIR ${openfst_BINARY_DIR} + BUILD_ALWAYS 0 CONFIGURE_COMMAND ${openfst_SOURCE_DIR}/configure --prefix=${openfst_PREFIX_DIR} "CPPFLAGS=-I${gflags_BINARY_DIR}/include -I${glog_SOURCE_DIR}/src -I${glog_BINARY_DIR}" "LDFLAGS=-L${gflags_BINARY_DIR} -L${glog_BINARY_DIR}" - "LIBS=-lgflags_nothreads -lglog -lpthread" + "LIBS=-lgflags_nothreads -lglog -lpthread -fPIC" COMMAND ${CMAKE_COMMAND} -E copy_directory ${PROJECT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR} BUILD_COMMAND make -j 4 ) link_directories(${openfst_PREFIX_DIR}/lib) include_directories(${openfst_PREFIX_DIR}/include) + + +message(STATUS "OpenFST inc dir: ${openfst_PREFIX_DIR}/include") +message(STATUS "OpenFST lib dir: ${openfst_PREFIX_DIR}/lib") diff --git a/speechx/cmake/paddleinference.cmake b/runtime/cmake/paddleinference.cmake similarity index 100% rename from speechx/cmake/paddleinference.cmake rename to runtime/cmake/paddleinference.cmake diff --git a/runtime/cmake/pybind.cmake b/runtime/cmake/pybind.cmake new file mode 100644 index 00000000..0ce1f57f --- /dev/null +++ b/runtime/cmake/pybind.cmake @@ -0,0 +1,42 @@ +#the pybind11 is from:https://github.com/pybind/pybind11 +# Copyright (c) 2016 Wenzel Jakob , All rights reserved. + +SET(PYBIND_ZIP "v2.10.0.zip") +SET(LOCAL_PYBIND_ZIP ${FETCHCONTENT_BASE_DIR}/${PYBIND_ZIP}) +SET(PYBIND_SRC ${FETCHCONTENT_BASE_DIR}/pybind11) +SET(DOWNLOAD_URL "https://paddleaudio.bj.bcebos.com/build/v2.10.0.zip") +SET(PYBIND_TIMEOUT 600 CACHE STRING "Timeout in seconds when downloading pybind.") + +IF(NOT EXISTS ${LOCAL_PYBIND_ZIP}) + FILE(DOWNLOAD ${DOWNLOAD_URL} + ${LOCAL_PYBIND_ZIP} + TIMEOUT ${PYBIND_TIMEOUT} + STATUS ERR + SHOW_PROGRESS + ) + + IF(ERR EQUAL 0) + MESSAGE(STATUS "download pybind success") + ELSE() + MESSAGE(FATAL_ERROR "download pybind fail") + ENDIF() +ENDIF() + +IF(NOT EXISTS ${PYBIND_SRC}) + EXECUTE_PROCESS( + COMMAND ${CMAKE_COMMAND} -E tar xfz ${LOCAL_PYBIND_ZIP} + WORKING_DIRECTORY ${FETCHCONTENT_BASE_DIR} + RESULT_VARIABLE tar_result + ) + + file(RENAME ${FETCHCONTENT_BASE_DIR}/pybind11-2.10.0 ${PYBIND_SRC}) + + IF (tar_result MATCHES 0) + MESSAGE(STATUS "unzip pybind success") + ELSE() + MESSAGE(FATAL_ERROR "unzip pybind fail") + ENDIF() + +ENDIF() + +include_directories(${PYBIND_SRC}/include) diff --git a/runtime/cmake/summary.cmake b/runtime/cmake/summary.cmake new file mode 100644 index 00000000..95ee324a --- /dev/null +++ b/runtime/cmake/summary.cmake @@ -0,0 +1,64 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +function(pps_summary) + message(STATUS "") + message(STATUS "*************PaddleSpeech Building Summary**********") + message(STATUS " PPS_VERSION : ${PPS_VERSION}") + message(STATUS " CMake version : ${CMAKE_VERSION}") + message(STATUS " CMake command : ${CMAKE_COMMAND}") + message(STATUS " UNIX : ${UNIX}") + message(STATUS " ANDROID : ${ANDROID}") + message(STATUS " System : ${CMAKE_SYSTEM_NAME}") + message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}") + message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}") + message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}") + message(STATUS " Build type : ${CMAKE_BUILD_TYPE}") + message(STATUS " BUILD_SHARED_LIBS : ${BUILD_SHARED_LIBS}") + get_directory_property(tmp DIRECTORY ${PROJECT_SOURCE_DIR} COMPILE_DEFINITIONS) + message(STATUS " Compile definitions : ${tmp}") + message(STATUS " CMAKE_PREFIX_PATH : ${CMAKE_PREFIX_PATH}") + message(STATUS " CMAKE_CURRENT_BINARY_DIR : ${CMAKE_CURRENT_BINARY_DIR}") + message(STATUS " CMAKE_INSTALL_PREFIX : ${CMAKE_INSTALL_PREFIX}") + message(STATUS " CMAKE_INSTALL_LIBDIR : ${CMAKE_INSTALL_LIBDIR}") + message(STATUS " CMAKE_MODULE_PATH : ${CMAKE_MODULE_PATH}") + message(STATUS " CMAKE_SYSTEM_NAME : ${CMAKE_SYSTEM_NAME}") + message(STATUS "") + + message(STATUS " WITH_ASR : ${WITH_ASR}") + message(STATUS " WITH_CLS : ${WITH_CLS}") + message(STATUS " WITH_VAD : ${WITH_VAD}") + message(STATUS " WITH_GPU : ${WITH_GPU}") + message(STATUS " WITH_TESTING : ${WITH_TESTING}") + message(STATUS " WITH_PROFILING : ${WITH_PROFILING}") + message(STATUS " FASTDEPLOY_INSTALL_DIR : ${FASTDEPLOY_INSTALL_DIR}") + message(STATUS " FASTDEPLOY_INCS : ${FASTDEPLOY_INCS}") + message(STATUS " FASTDEPLOY_LIBS : ${FASTDEPLOY_LIBS}") + if(WITH_GPU) + message(STATUS " CUDA_DIRECTORY : ${CUDA_DIRECTORY}") + endif() + + if(ANDROID) + message(STATUS " ANDROID_ABI : ${ANDROID_ABI}") + message(STATUS " ANDROID_PLATFORM : ${ANDROID_PLATFORM}") + message(STATUS " ANDROID_NDK : ${ANDROID_NDK}") + message(STATUS " ANDROID_NDK_VERSION : ${CMAKE_ANDROID_NDK_VERSION}") + endif() + if (WITH_ASR) + message(STATUS " Python executable : ${PYTHON_EXECUTABLE}") + message(STATUS " Python includes : ${PYTHON_INCLUDE_DIR}") + endif() +endfunction() + +pps_summary() \ No newline at end of file diff --git a/runtime/cmake/system.cmake b/runtime/cmake/system.cmake new file mode 100644 index 00000000..580e07bb --- /dev/null +++ b/runtime/cmake/system.cmake @@ -0,0 +1,106 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Detects the OS and sets appropriate variables. 
+# CMAKE_SYSTEM_NAME only give us a coarse-grained name of the OS CMake is +# building for, but the host processor name like centos is necessary +# in some scenes to distinguish system for customization. +# +# for instance, protobuf libs path is /lib64 +# on CentOS, but /lib on other systems. + +if(UNIX AND NOT APPLE) + # except apple from nix*Os family + set(LINUX TRUE) +endif() + +if(WIN32) + set(HOST_SYSTEM "win32") +else() + if(APPLE) + set(HOST_SYSTEM "macosx") + exec_program( + sw_vers ARGS + -productVersion + OUTPUT_VARIABLE HOST_SYSTEM_VERSION) + string(REGEX MATCH "[0-9]+.[0-9]+" MACOS_VERSION "${HOST_SYSTEM_VERSION}") + if(NOT DEFINED $ENV{MACOSX_DEPLOYMENT_TARGET}) + # Set cache variable - end user may change this during ccmake or cmake-gui configure. + set(CMAKE_OSX_DEPLOYMENT_TARGET + ${MACOS_VERSION} + CACHE + STRING + "Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value." + ) + endif() + set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") + else() + + if(EXISTS "/etc/issue") + file(READ "/etc/issue" LINUX_ISSUE) + if(LINUX_ISSUE MATCHES "CentOS") + set(HOST_SYSTEM "centos") + elseif(LINUX_ISSUE MATCHES "Debian") + set(HOST_SYSTEM "debian") + elseif(LINUX_ISSUE MATCHES "Ubuntu") + set(HOST_SYSTEM "ubuntu") + elseif(LINUX_ISSUE MATCHES "Red Hat") + set(HOST_SYSTEM "redhat") + elseif(LINUX_ISSUE MATCHES "Fedora") + set(HOST_SYSTEM "fedora") + endif() + + string(REGEX MATCH "(([0-9]+)\\.)+([0-9]+)" HOST_SYSTEM_VERSION + "${LINUX_ISSUE}") + endif() + + if(EXISTS "/etc/redhat-release") + file(READ "/etc/redhat-release" LINUX_ISSUE) + if(LINUX_ISSUE MATCHES "CentOS") + set(HOST_SYSTEM "centos") + endif() + endif() + + if(NOT HOST_SYSTEM) + set(HOST_SYSTEM ${CMAKE_SYSTEM_NAME}) + endif() + + endif() +endif() + +# query number of logical cores +cmake_host_system_information(RESULT CPU_CORES QUERY NUMBER_OF_LOGICAL_CORES) + +mark_as_advanced(HOST_SYSTEM CPU_CORES) + +message( + STATUS + "Found Paddle host system: ${HOST_SYSTEM}, version: ${HOST_SYSTEM_VERSION}") +message(STATUS "Found Paddle host system's CPU: ${CPU_CORES} cores") + +# external dependencies log output +set(EXTERNAL_PROJECT_LOG_ARGS + LOG_DOWNLOAD + 0 # Wrap download in script to log output + LOG_UPDATE + 1 # Wrap update in script to log output + LOG_CONFIGURE + 1 # Wrap configure in script to log output + LOG_BUILD + 0 # Wrap build in script to log output + LOG_TEST + 1 # Wrap test in script to log output + LOG_INSTALL + 0 # Wrap install in script to log output +) \ No newline at end of file diff --git a/speechx/speechx/kaldi/.gitkeep b/runtime/docker/.gitkeep similarity index 100% rename from speechx/speechx/kaldi/.gitkeep rename to runtime/docker/.gitkeep diff --git a/runtime/engine/CMakeLists.txt b/runtime/engine/CMakeLists.txt new file mode 100644 index 00000000..d64df648 --- /dev/null +++ b/runtime/engine/CMakeLists.txt @@ -0,0 +1,22 @@ +project(speechx LANGUAGES CXX) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/kaldi) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/common) + +add_subdirectory(kaldi) +add_subdirectory(common) + +if(WITH_ASR) + add_subdirectory(asr) +endif() + +if(WITH_CLS) + add_subdirectory(audio_classification) +endif() + +if(WITH_VAD) + add_subdirectory(vad) +endif() + +add_subdirectory(codelab) diff --git a/runtime/engine/asr/CMakeLists.txt b/runtime/engine/asr/CMakeLists.txt new file mode 100644 index 00000000..ff4cdecb --- /dev/null +++ 
b/runtime/engine/asr/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +project(ASR LANGUAGES CXX) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/server) + +add_subdirectory(decoder) +add_subdirectory(recognizer) +add_subdirectory(nnet) +add_subdirectory(server) diff --git a/runtime/engine/asr/decoder/CMakeLists.txt b/runtime/engine/asr/decoder/CMakeLists.txt new file mode 100644 index 00000000..2a20f446 --- /dev/null +++ b/runtime/engine/asr/decoder/CMakeLists.txt @@ -0,0 +1,24 @@ +set(srcs) +list(APPEND srcs + ctc_prefix_beam_search_decoder.cc + ctc_tlg_decoder.cc +) + +add_library(decoder STATIC ${srcs}) +target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder) + +# test +set(TEST_BINS + ctc_prefix_beam_search_decoder_main + ctc_tlg_decoder_main +) + +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl) +endforeach() + diff --git a/speechx/speechx/decoder/common.h b/runtime/engine/asr/decoder/common.h similarity index 100% rename from speechx/speechx/decoder/common.h rename to runtime/engine/asr/decoder/common.h diff --git a/speechx/speechx/decoder/ctc_beam_search_opt.h b/runtime/engine/asr/decoder/ctc_beam_search_opt.h similarity index 52% rename from speechx/speechx/decoder/ctc_beam_search_opt.h rename to runtime/engine/asr/decoder/ctc_beam_search_opt.h index f4a81b3a..4c145370 100644 --- a/speechx/speechx/decoder/ctc_beam_search_opt.h +++ b/runtime/engine/asr/decoder/ctc_beam_search_opt.h @@ -22,51 +22,22 @@ namespace ppspeech { struct CTCBeamSearchOptions { // common int blank; - - // ds2 - std::string dict_file; - std::string lm_path; - int beam_size; - BaseFloat alpha; - BaseFloat beta; - BaseFloat cutoff_prob; - int cutoff_top_n; - int num_proc_bsearch; + std::string word_symbol_table; // u2 int first_beam_size; int second_beam_size; + CTCBeamSearchOptions() : blank(0), - dict_file("vocab.txt"), - lm_path(""), - beam_size(300), - alpha(1.9f), - beta(5.0), - cutoff_prob(0.99f), - cutoff_top_n(40), - num_proc_bsearch(10), + word_symbol_table("vocab.txt"), first_beam_size(10), second_beam_size(10) {} void Register(kaldi::OptionsItf* opts) { - std::string module = "Ds2BeamSearchConfig: "; - opts->Register("dict", &dict_file, module + "vocab file path."); - opts->Register( - "lm-path", &lm_path, module + "ngram language model path."); - opts->Register("alpha", &alpha, module + "alpha"); - opts->Register("beta", &beta, module + "beta"); - opts->Register("beam-size", - &beam_size, - module + "beam size for beam search method"); - opts->Register("cutoff-prob", &cutoff_prob, module + "cutoff probs"); - opts->Register("cutoff-top-n", &cutoff_top_n, module + "cutoff top n"); - opts->Register( - "num-proc-bsearch", &num_proc_bsearch, module + "num proc bsearch"); - + std::string module = "CTCBeamSearchOptions: "; + opts->Register("word_symbol_table", &word_symbol_table, module + "vocab file path."); opts->Register("blank", &blank, "blank id, default is 0."); - - module = "U2BeamSearchConfig: "; 
        opts->Register(
            "first-beam-size", &first_beam_size, module + "first beam size.");
        opts->Register("second-beam-size",
diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc
similarity index 96%
rename from speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc
index 07e8e560..bf912af2 100644
--- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.cc
+++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.cc
@@ -17,13 +17,12 @@
 
 #include "decoder/ctc_prefix_beam_search_decoder.h"
 
-#include "absl/strings/str_join.h"
 #include "base/common.h"
 #include "decoder/ctc_beam_search_opt.h"
 #include "decoder/ctc_prefix_beam_search_score.h"
 #include "utils/math.h"
 
-#ifdef USE_PROFILING
+#ifdef WITH_PROFILING
 #include "paddle/fluid/platform/profiler.h"
 using paddle::platform::RecordEvent;
 using paddle::platform::TracerEventType;
@@ -31,11 +30,10 @@ using paddle::platform::TracerEventType;
 
 namespace ppspeech {
 
-CTCPrefixBeamSearch::CTCPrefixBeamSearch(const std::string& vocab_path,
-                                         const CTCBeamSearchOptions& opts)
+CTCPrefixBeamSearch::CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts)
     : opts_(opts) {
     unit_table_ = std::shared_ptr<fst::SymbolTable>(
-        fst::SymbolTable::ReadText(vocab_path));
+        fst::SymbolTable::ReadText(opts.word_symbol_table));
     CHECK(unit_table_ != nullptr);
 
     Reset();
@@ -66,7 +64,6 @@ void CTCPrefixBeamSearch::Reset() {
 
 void CTCPrefixBeamSearch::InitDecoder() { Reset(); }
 
-
 void CTCPrefixBeamSearch::AdvanceDecode(
    const std::shared_ptr<kaldi::DecodableInterface>& decodable) {
    double search_cost = 0.0;
@@ -78,21 +75,21 @@ void CTCPrefixBeamSearch::AdvanceDecode(
         bool flag = decodable->FrameLikelihood(num_frame_decoded_, &frame_prob);
         feat_nnet_cost += timer.Elapsed();
         if (flag == false) {
-            VLOG(3) << "decoder advance decode exit." << frame_prob.size();
+            VLOG(2) << "decoder advance decode exit." << frame_prob.size();
             break;
         }
 
         timer.Reset();
         std::vector<std::vector<kaldi::BaseFloat>> likelihood;
-        likelihood.push_back(frame_prob);
+        likelihood.push_back(std::move(frame_prob));
         AdvanceDecoding(likelihood);
         search_cost += timer.Elapsed();
 
-        VLOG(2) << "num_frame_decoded_: " << num_frame_decoded_;
+        VLOG(1) << "num_frame_decoded_: " << num_frame_decoded_;
     }
-    VLOG(1) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
+    VLOG(2) << "AdvanceDecode feat + forward cost: " << feat_nnet_cost
             << " sec.";
-    VLOG(1) << "AdvanceDecode search cost: " << search_cost << " sec.";
+    VLOG(2) << "AdvanceDecode search cost: " << search_cost << " sec.";
 }
 
 static bool PrefixScoreCompare(
@@ -105,7 +102,7 @@ static bool PrefixScoreCompare(
 
 void CTCPrefixBeamSearch::AdvanceDecoding(
     const std::vector<std::vector<kaldi::BaseFloat>>& logp) {
-#ifdef USE_PROFILING
+#ifdef WITH_PROFILING
     RecordEvent event("CtcPrefixBeamSearch::AdvanceDecoding",
                       TracerEventType::UserDefined,
                       1);
diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h
similarity index 94%
rename from speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h
index 5013246a..391b4073 100644
--- a/speechx/speechx/decoder/ctc_prefix_beam_search_decoder.h
+++ b/runtime/engine/asr/decoder/ctc_prefix_beam_search_decoder.h
@@ -27,8 +27,7 @@ namespace ppspeech {
 class ContextGraph;
 class CTCPrefixBeamSearch : public DecoderBase {
   public:
-    CTCPrefixBeamSearch(const std::string& vocab_path,
-                        const CTCBeamSearchOptions& opts);
+    CTCPrefixBeamSearch(const CTCBeamSearchOptions& opts);
     ~CTCPrefixBeamSearch() {}
 
     SearchType Type() const { return SearchType::kPrefixBeamSearch; }
@@ -45,7 +44,7 @@ class CTCPrefixBeamSearch : public DecoderBase {
 
     void FinalizeSearch();
 
-    const std::shared_ptr<fst::SymbolTable> VocabTable() const {
+    const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
         return unit_table_;
     }
 
@@ -57,7 +56,6 @@ class CTCPrefixBeamSearch : public DecoderBase {
     }
     const std::vector<std::vector<int>>& Times() const { return times_; }
 
-
   protected:
     std::string GetBestPath() override;
    std::vector<std::pair<double, std::string>> GetNBestPath() override;
-#include "absl/strings/str_split.h" #include "base/common.h" #include "decoder/ctc_prefix_beam_search_decoder.h" -#include "frontend/audio/data_cache.h" +#include "frontend/data_cache.h" #include "fst/symbol-table.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" +#include "nnet/nnet_producer.h" #include "nnet/u2_nnet.h" DEFINE_string(feature_rspecifier, "", "test feature rspecifier"); DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_string(vocab_path, "", "vocab path"); +DEFINE_string(word_symbol_table, "", "vocab path"); DEFINE_string(model_path, "", "paddle nnet model"); @@ -40,7 +40,7 @@ using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; -// test ds2 online decoder by feeding speech feature +// test u2 online decoder by feeding speech feature int main(int argc, char* argv[]) { gflags::SetUsageMessage("Usage:"); gflags::ParseCommandLineFlags(&argc, &argv, false); @@ -52,10 +52,10 @@ int main(int argc, char* argv[]) { CHECK_NE(FLAGS_result_wspecifier, ""); CHECK_NE(FLAGS_feature_rspecifier, ""); - CHECK_NE(FLAGS_vocab_path, ""); + CHECK_NE(FLAGS_word_symbol_table, ""); CHECK_NE(FLAGS_model_path, ""); LOG(INFO) << "model path: " << FLAGS_model_path; - LOG(INFO) << "Reading vocab table " << FLAGS_vocab_path; + LOG(INFO) << "Reading vocab table " << FLAGS_word_symbol_table; kaldi::SequentialBaseFloatMatrixReader feature_reader( FLAGS_feature_rspecifier); @@ -70,15 +70,18 @@ int main(int argc, char* argv[]) { // decodeable std::shared_ptr raw_data = std::make_shared(); + std::shared_ptr nnet_producer = + std::make_shared(nnet, raw_data, 1.0); std::shared_ptr decodable = - std::make_shared(nnet, raw_data); + std::make_shared(nnet_producer); // decoder ppspeech::CTCBeamSearchOptions opts; opts.blank = 0; opts.first_beam_size = 10; opts.second_beam_size = 10; - ppspeech::CTCPrefixBeamSearch decoder(FLAGS_vocab_path, opts); + opts.word_symbol_table = FLAGS_word_symbol_table; + ppspeech::CTCPrefixBeamSearch decoder(opts); int32 chunk_size = FLAGS_receptive_field_length + @@ -122,15 +125,14 @@ int main(int argc, char* argv[]) { } - kaldi::Vector feature_chunk(this_chunk_size * - feat_dim); + std::vector feature_chunk(this_chunk_size * + feat_dim); int32 start = chunk_idx * chunk_stride; for (int row_id = 0; row_id < this_chunk_size; ++row_id) { kaldi::SubVector feat_row(feature, start); - kaldi::SubVector feature_chunk_row( - feature_chunk.Data() + row_id * feat_dim, feat_dim); - - feature_chunk_row.CopyFromVec(feat_row); + std::memcpy(feature_chunk.data() + row_id * feat_dim, + feat_row.Data(), + feat_dim * sizeof(kaldi::BaseFloat)); ++start; } diff --git a/speechx/speechx/decoder/ctc_prefix_beam_search_score.h b/runtime/engine/asr/decoder/ctc_prefix_beam_search_score.h similarity index 100% rename from speechx/speechx/decoder/ctc_prefix_beam_search_score.h rename to runtime/engine/asr/decoder/ctc_prefix_beam_search_score.h diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.cc b/runtime/engine/asr/decoder/ctc_tlg_decoder.cc similarity index 62% rename from speechx/speechx/decoder/ctc_tlg_decoder.cc rename to runtime/engine/asr/decoder/ctc_tlg_decoder.cc index 2c2b6d3c..51ded499 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.cc +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder.cc @@ -13,12 +13,14 @@ // limitations under the License. 
#include "decoder/ctc_tlg_decoder.h" + namespace ppspeech { -TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { - fst_.reset(fst::Fst::Read(opts.fst_path)); +TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) { + fst_ = opts.fst_ptr; CHECK(fst_ != nullptr); + CHECK(!opts.word_symbol_table.empty()); word_symbol_table_.reset( fst::SymbolTable::ReadText(opts.word_symbol_table)); @@ -29,6 +31,11 @@ TLGDecoder::TLGDecoder(TLGDecoderOptions opts) { void TLGDecoder::Reset() { decoder_->InitDecoding(); + hypotheses_.clear(); + likelihood_.clear(); + olabels_.clear(); + times_.clear(); + num_frame_decoded_ = 0; return; } @@ -68,14 +75,52 @@ std::string TLGDecoder::GetPartialResult() { return words; } +void TLGDecoder::FinalizeSearch() { + decoder_->FinalizeDecoding(); + kaldi::CompactLattice clat; + decoder_->GetLattice(&clat, true); + kaldi::Lattice lat, nbest_lat; + fst::ConvertLattice(clat, &lat); + fst::ShortestPath(lat, &nbest_lat, opts_.nbest); + std::vector nbest_lats; + fst::ConvertNbestToVector(nbest_lat, &nbest_lats); + + hypotheses_.clear(); + hypotheses_.reserve(nbest_lats.size()); + likelihood_.clear(); + likelihood_.reserve(nbest_lats.size()); + times_.clear(); + times_.reserve(nbest_lats.size()); + for (auto lat : nbest_lats) { + kaldi::LatticeWeight weight; + std::vector hypothese; + std::vector time; + std::vector alignment; + std::vector words_id; + fst::GetLinearSymbolSequence(lat, &alignment, &words_id, &weight); + int idx = 0; + for (; idx < alignment.size() - 1; ++idx) { + if (alignment[idx] == 0) continue; + if (alignment[idx] != alignment[idx + 1]) { + hypothese.push_back(alignment[idx] - 1); + time.push_back(idx); // fake time, todo later + } + } + hypothese.push_back(alignment[idx] - 1); + time.push_back(idx); // fake time, todo later + hypotheses_.push_back(hypothese); + times_.push_back(time); + olabels_.push_back(words_id); + likelihood_.push_back(-(weight.Value2() + weight.Value1())); + } +} + std::string TLGDecoder::GetFinalBestPath() { if (num_frame_decoded_ == 0) { // Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call // BestPathEnd if no frames were decoded.") return std::string(""); } - - decoder_->FinalizeDecoding(); kaldi::Lattice lat; kaldi::LatticeWeight weight; std::vector alignment; diff --git a/speechx/speechx/decoder/ctc_tlg_decoder.h b/runtime/engine/asr/decoder/ctc_tlg_decoder.h similarity index 67% rename from speechx/speechx/decoder/ctc_tlg_decoder.h rename to runtime/engine/asr/decoder/ctc_tlg_decoder.h index 8be69dad..80896361 100644 --- a/speechx/speechx/decoder/ctc_tlg_decoder.h +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder.h @@ -18,13 +18,14 @@ #include "decoder/decoder_itf.h" #include "kaldi/decoder/lattice-faster-online-decoder.h" #include "util/parse-options.h" +#include "utils/file_utils.h" - -DECLARE_string(graph_path); DECLARE_string(word_symbol_table); +DECLARE_string(graph_path); DECLARE_int32(max_active); DECLARE_double(beam); DECLARE_double(lattice_beam); +DECLARE_int32(nbest); namespace ppspeech { @@ -33,17 +34,27 @@ struct TLGDecoderOptions { // todo remove later, add into decode resource std::string word_symbol_table; std::string fst_path; + std::shared_ptr> fst_ptr; + int nbest; + + TLGDecoderOptions() : word_symbol_table(""), fst_path(""), fst_ptr(nullptr), nbest(10) {} static TLGDecoderOptions InitFromFlags() { TLGDecoderOptions decoder_opts; decoder_opts.word_symbol_table = FLAGS_word_symbol_table; decoder_opts.fst_path = FLAGS_graph_path; LOG(INFO) << "fst path: " << decoder_opts.fst_path; - LOG(INFO) 
<< "fst symbole table: " << decoder_opts.word_symbol_table; + LOG(INFO) << "symbole table: " << decoder_opts.word_symbol_table; + + if (!decoder_opts.fst_path.empty()) { + CHECK(FileExists(decoder_opts.fst_path)); + decoder_opts.fst_ptr.reset(fst::Fst::Read(FLAGS_graph_path)); + } decoder_opts.opts.max_active = FLAGS_max_active; decoder_opts.opts.beam = FLAGS_beam; decoder_opts.opts.lattice_beam = FLAGS_lattice_beam; + decoder_opts.nbest = FLAGS_nbest; LOG(INFO) << "LatticeFasterDecoder max active: " << decoder_opts.opts.max_active; LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam; @@ -59,20 +70,38 @@ class TLGDecoder : public DecoderBase { explicit TLGDecoder(TLGDecoderOptions opts); ~TLGDecoder() = default; - void InitDecoder(); - void Reset(); + void InitDecoder() override; + void Reset() override; void AdvanceDecode( - const std::shared_ptr& decodable); + const std::shared_ptr& decodable) override; void Decode(); std::string GetFinalBestPath() override; std::string GetPartialResult() override; + const std::shared_ptr WordSymbolTable() const override { + return word_symbol_table_; + } + int DecodeLikelihoods(const std::vector>& probs, const std::vector& nbest_words); + void FinalizeSearch() override; + const std::vector>& Inputs() const override { + return hypotheses_; + } + const std::vector>& Outputs() const override { + return olabels_; + } // outputs_; } + const std::vector& Likelihood() const override { + return likelihood_; + } + const std::vector>& Times() const override { + return times_; + } + protected: std::string GetBestPath() override { CHECK(false); @@ -90,10 +119,17 @@ class TLGDecoder : public DecoderBase { private: void AdvanceDecoding(kaldi::DecodableInterface* decodable); + int num_frame_decoded_; + std::vector> hypotheses_; + std::vector> olabels_; + std::vector likelihood_; + std::vector> times_; + std::shared_ptr decoder_; std::shared_ptr> fst_; std::shared_ptr word_symbol_table_; + TLGDecoderOptions opts_; }; -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/decoder/nnet_logprob_decoder_main.cc b/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc similarity index 50% rename from speechx/speechx/decoder/nnet_logprob_decoder_main.cc rename to runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc index e0acbe77..dcd18b81 100644 --- a/speechx/speechx/decoder/nnet_logprob_decoder_main.cc +++ b/runtime/engine/asr/decoder/ctc_tlg_decoder_main.cc @@ -14,21 +14,24 @@ // todo refactor, repalce with gtest -#include "base/flags.h" -#include "base/log.h" -#include "decoder/ctc_beam_search_decoder.h" +#include "base/common.h" +#include "decoder/ctc_tlg_decoder.h" +#include "decoder/param.h" +#include "frontend/data_cache.h" #include "kaldi/util/table-types.h" #include "nnet/decodable.h" +#include "nnet/nnet_producer.h" + + +DEFINE_string(nnet_prob_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); -DEFINE_string(nnet_prob_respecifier, "", "test nnet prob rspecifier"); -DEFINE_string(dict_file, "vocab.txt", "vocabulary of lm"); -DEFINE_string(lm_path, "lm.klm", "language model"); using kaldi::BaseFloat; using kaldi::Matrix; using std::vector; -// test decoder by feeding nnet posterior probability +// test TLG decoder by feeding speech feature. 
 int main(int argc, char* argv[]) {
     gflags::SetUsageMessage("Usage:");
     gflags::ParseCommandLineFlags(&argc, &argv, false);
@@ -36,41 +39,51 @@
     google::InstallFailureSignalHandler();
     FLAGS_logtostderr = 1;
 
-    kaldi::SequentialBaseFloatMatrixReader likelihood_reader(
-        FLAGS_nnet_prob_respecifier);
-    std::string dict_file = FLAGS_dict_file;
-    std::string lm_path = FLAGS_lm_path;
-    LOG(INFO) << "dict path: " << dict_file;
-    LOG(INFO) << "lm path: " << lm_path;
+    kaldi::SequentialBaseFloatMatrixReader nnet_prob_reader(
+        FLAGS_nnet_prob_rspecifier);
+    kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
 
     int32 num_done = 0, num_err = 0;
 
-    ppspeech::CTCBeamSearchOptions opts;
-    opts.dict_file = dict_file;
-    opts.lm_path = lm_path;
-    ppspeech::CTCBeamSearch decoder(opts);
+    ppspeech::TLGDecoderOptions opts =
+        ppspeech::TLGDecoderOptions::InitFromFlags();
+    opts.opts.beam = 15.0;
+    opts.opts.lattice_beam = 7.5;
+    ppspeech::TLGDecoder decoder(opts);
+
+    ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
 
+    std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
+        std::make_shared<ppspeech::NnetProducer>(nullptr, nullptr, 1.0);
     std::shared_ptr<ppspeech::Decodable> decodable(
-        new ppspeech::Decodable(nullptr, nullptr));
+        new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale));
 
     decoder.InitDecoder();
+    kaldi::Timer timer;
 
-    for (; !likelihood_reader.Done(); likelihood_reader.Next()) {
-        string utt = likelihood_reader.Key();
-        const kaldi::Matrix<BaseFloat> likelihood = likelihood_reader.Value();
-        LOG(INFO) << "process utt: " << utt;
-        LOG(INFO) << "rows: " << likelihood.NumRows();
-        LOG(INFO) << "cols: " << likelihood.NumCols();
-        decodable->Acceptlikelihood(likelihood);
+    for (; !nnet_prob_reader.Done(); nnet_prob_reader.Next()) {
+        string utt = nnet_prob_reader.Key();
+        kaldi::Matrix<BaseFloat> prob = nnet_prob_reader.Value();
+        decodable->Acceptlikelihood(prob);
         decoder.AdvanceDecode(decodable);
         std::string result;
         result = decoder.GetFinalBestPath();
-        KALDI_LOG << " the result of " << utt << " is " << result;
         decodable->Reset();
         decoder.Reset();
+        if (result.empty()) {
+            // the TokenWriter can not write empty string.
+            ++num_err;
+            KALDI_LOG << " the result of " << utt << " is empty";
+            continue;
+        }
+        KALDI_LOG << " the result of " << utt << " is " << result;
+        result_writer.Write(utt, result);
         ++num_done;
     }
 
+    double elapsed = timer.Elapsed();
+    KALDI_LOG << " cost:" << elapsed << " s";
+
     KALDI_LOG << "Done " << num_done << " utterances, " << num_err
               << " with errors.";
     return (num_done != 0 ? 0 : 1);
diff --git a/speechx/speechx/decoder/decoder_itf.h b/runtime/engine/asr/decoder/decoder_itf.h
similarity index 79%
rename from speechx/speechx/decoder/decoder_itf.h
rename to runtime/engine/asr/decoder/decoder_itf.h
index 2289b317..cb7717e8 100644
--- a/speechx/speechx/decoder/decoder_itf.h
+++ b/runtime/engine/asr/decoder/decoder_itf.h
@@ -1,4 +1,3 @@
-
 // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +15,7 @@
 #pragma once
 
 #include "base/common.h"
+#include "fst/symbol-table.h"
 #include "kaldi/decoder/decodable-itf.h"
 
 namespace ppspeech {
 
@@ -41,6 +41,14 @@ class DecoderInterface {
 
     virtual std::string GetPartialResult() = 0;
 
+    virtual const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const = 0;
+    virtual void FinalizeSearch() = 0;
+
+    virtual const std::vector<std::vector<int>>& Inputs() const = 0;
+    virtual const std::vector<std::vector<int>>& Outputs() const = 0;
+    virtual const std::vector<kaldi::BaseFloat>& Likelihood() const = 0;
+    virtual const std::vector<std::vector<int>>& Times() const = 0;
+
   protected:
     // virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;
diff --git a/speechx/speechx/decoder/param.h b/runtime/engine/asr/decoder/param.h
similarity index 73%
rename from speechx/speechx/decoder/param.h
rename to runtime/engine/asr/decoder/param.h
index ebdd7119..0cad75bf 100644
--- a/speechx/speechx/decoder/param.h
+++ b/runtime/engine/asr/decoder/param.h
@@ -15,8 +15,6 @@
 #pragma once
 
 #include "base/common.h"
-#include "decoder/ctc_beam_search_decoder.h"
-#include "decoder/ctc_tlg_decoder.h"
 
 // feature
 DEFINE_bool(use_fbank, false, "False for fbank; or linear feature");
@@ -37,36 +35,22 @@ DEFINE_int32(subsampling_rate,
             "two CNN(kernel=3) module downsampling rate.");
 DEFINE_int32(nnet_decoder_chunk, 1, "paddle nnet forward chunk");
 
-
 // nnet
-DEFINE_string(vocab_path, "", "nnet vocab path.");
 DEFINE_string(model_path, "avg_1.jit.pdmodel", "paddle nnet model");
-DEFINE_string(param_path, "avg_1.jit.pdiparams", "paddle nnet model param");
-DEFINE_string(
-    model_input_names,
-    "audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
-    "model input names");
-DEFINE_string(model_output_names,
-              "softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
-              "model output names");
-DEFINE_string(model_cache_names,
-              "chunk_state_h_box,chunk_state_c_box",
-              "model cache names");
-DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
-
+#ifdef USE_ONNX
+DEFINE_bool(with_onnx_model, false, "True means the model path is an onnx model path");
+#endif
 
 // decoder
 DEFINE_double(acoustic_scale, 1.0, "acoustic scale");
-
-DEFINE_string(graph_path, "TLG", "decoder graph");
-DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
+DEFINE_string(graph_path, "", "decoder graph");
+DEFINE_string(word_symbol_table, "", "word symbol table");
 DEFINE_int32(max_active, 7500, "max active");
 DEFINE_double(beam, 15.0, "decoder beam");
 DEFINE_double(lattice_beam, 7.5, "decoder beam");
-
+DEFINE_double(blank_threshold, 0.98, "blank skip threshold");
 
 // DecodeOptions flags
-// DEFINE_int32(chunk_size, -1, "decoding chunk size");
 DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
 DEFINE_double(ctc_weight,
               0.5,
diff --git a/runtime/engine/asr/nnet/CMakeLists.txt b/runtime/engine/asr/nnet/CMakeLists.txt
new file mode 100644
index 00000000..1adcbfeb
--- /dev/null
+++ b/runtime/engine/asr/nnet/CMakeLists.txt
@@ -0,0 +1,21 @@
+set(srcs decodable.cc nnet_producer.cc)
+
+list(APPEND srcs u2_nnet.cc)
+if(WITH_ONNX)
+    list(APPEND srcs u2_onnx_nnet.cc)
+endif()
+add_library(nnet STATIC ${srcs})
+target_link_libraries(nnet utils)
+if(WITH_ONNX)
+    target_link_libraries(nnet ${FASTDEPLOY_LIBS})
+endif()
+
+target_compile_options(nnet PUBLIC ${PADDLE_COMPILE_FLAGS})
+target_include_directories(nnet PUBLIC ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
+
+# test bin
+#set(bin_name u2_nnet_main)
+#add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
+#target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS})
+#target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR})
+#target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS})
\ No newline at end of file
diff --git a/speechx/speechx/nnet/decodable.cc b/runtime/engine/asr/nnet/decodable.cc
similarity index 54%
rename from speechx/speechx/nnet/decodable.cc
rename to runtime/engine/asr/nnet/decodable.cc
index 5fe2b984..a140c376 100644
--- a/speechx/speechx/nnet/decodable.cc
+++ b/runtime/engine/asr/nnet/decodable.cc
@@ -21,29 +21,25 @@ using kaldi::Matrix;
 using kaldi::Vector;
 using std::vector;
 
-Decodable::Decodable(const std::shared_ptr<NnetBase>& nnet,
-                     const std::shared_ptr<FrontendInterface>& frontend,
+Decodable::Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
                      kaldi::BaseFloat acoustic_scale)
-    : frontend_(frontend),
-      nnet_(nnet),
+    : nnet_producer_(nnet_producer),
       frame_offset_(0),
      frames_ready_(0),
      acoustic_scale_(acoustic_scale) {}
 
 // for debug
 void Decodable::Acceptlikelihood(const Matrix<BaseFloat>& likelihood) {
-    nnet_out_cache_ = likelihood;
-    frames_ready_ += likelihood.NumRows();
+    nnet_producer_->Acceptlikelihood(likelihood);
 }
 
-
 // return the size of frame have computed.
 int32 Decodable::NumFramesReady() const { return frames_ready_; }
 
 // frame idx is from 0 to frame_ready_ -1;
 bool Decodable::IsLastFrame(int32 frame) {
-    bool flag = EnsureFrameHaveComputed(frame);
+    EnsureFrameHaveComputed(frame);
     return frame >= frames_ready_;
 }
 
@@ -64,32 +60,10 @@ bool Decodable::EnsureFrameHaveComputed(int32 frame) {
 
 bool Decodable::AdvanceChunk() {
     kaldi::Timer timer;
-    // read feats
-    Vector<BaseFloat> features;
-    if (frontend_ == NULL || frontend_->Read(&features) == false) {
-        // no feat or frontend_ not init.
-        VLOG(3) << "decodable exit;";
-        return false;
-    }
-    CHECK_GE(frontend_->Dim(), 0);
-    VLOG(1) << "AdvanceChunk feat cost: " << timer.Elapsed() << " sec.";
-    VLOG(2) << "Forward in " << features.Dim() / frontend_->Dim() << " feats.";
-
-    // forward feats
-    NnetOut out;
-    nnet_->FeedForward(features, frontend_->Dim(), &out);
-    int32& vocab_dim = out.vocab_dim;
-    Vector<BaseFloat>& logprobs = out.logprobs;
-
-    VLOG(2) << "Forward out " << logprobs.Dim() / vocab_dim
-            << " decoder frames.";
-    // cache nnet outupts
-    nnet_out_cache_.Resize(logprobs.Dim() / vocab_dim, vocab_dim);
-    nnet_out_cache_.CopyRowsFromVec(logprobs);
-
-    // update state, decoding frame.
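+    // Note: after this refactor, feature reading and the nnet forward pass
+    // live in NnetProducer; Decodable just pops one frame of posteriors per
+    // AdvanceChunk() call.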
+    bool flag = nnet_producer_->Read(&framelikelihood_);
+    if (flag == false) return false;
     frame_offset_ = frames_ready_;
-    frames_ready_ += nnet_out_cache_.NumRows();
+    frames_ready_ += 1;
     VLOG(1) << "AdvanceChunk feat + forward cost: " << timer.Elapsed()
             << " sec.";
     return true;
@@ -101,17 +75,17 @@ bool Decodable::AdvanceChunk(kaldi::Vector<kaldi::BaseFloat>* logprobs,
         return false;
     }
 
-    int nrows = nnet_out_cache_.NumRows();
-    CHECK(nrows == (frames_ready_ - frame_offset_));
-    if (nrows <= 0) {
+    if (framelikelihood_.empty()) {
         LOG(WARNING) << "No new nnet out in cache.";
         return false;
     }
 
-    logprobs->Resize(nnet_out_cache_.NumRows() * nnet_out_cache_.NumCols());
-    logprobs->CopyRowsFromMat(nnet_out_cache_);
-
-    *vocab_dim = nnet_out_cache_.NumCols();
+    size_t dim = framelikelihood_.size();
+    logprobs->Resize(framelikelihood_.size());
+    std::memcpy(logprobs->Data(),
+                framelikelihood_.data(),
+                dim * sizeof(kaldi::BaseFloat));
+    *vocab_dim = framelikelihood_.size();
     return true;
 }
 
@@ -122,19 +96,8 @@ bool Decodable::FrameLikelihood(int32 frame, vector<BaseFloat>* likelihood) {
         return false;
     }
 
-    int nrows = nnet_out_cache_.NumRows();
-    CHECK(nrows == (frames_ready_ - frame_offset_));
-    int vocab_size = nnet_out_cache_.NumCols();
-    likelihood->resize(vocab_size);
-
-    for (int32 idx = 0; idx < vocab_size; ++idx) {
-        (*likelihood)[idx] =
-            nnet_out_cache_(frame - frame_offset_, idx) * acoustic_scale_;
-
-        VLOG(4) << "nnet out: " << frame << " offset:" << frame_offset_ << " "
-                << nnet_out_cache_.NumRows()
-                << " logprob: " << nnet_out_cache_(frame - frame_offset_, idx);
-    }
+    CHECK_EQ(1, (frames_ready_ - frame_offset_));
+    *likelihood = framelikelihood_;
     return true;
 }
 
@@ -143,37 +106,31 @@ BaseFloat Decodable::LogLikelihood(int32 frame, int32 index) {
         return false;
     }
 
-    CHECK_LE(index, nnet_out_cache_.NumCols());
+    CHECK_LE(index, framelikelihood_.size());
     CHECK_LE(frame, frames_ready_);
 
     // the nnet output is prob ranther than log prob
     // the index - 1, because the ilabel
     BaseFloat logprob = 0.0;
     int32 frame_idx = frame - frame_offset_;
-    BaseFloat nnet_out = nnet_out_cache_(frame_idx, TokenId2NnetId(index));
-    if (nnet_->IsLogProb()) {
-        logprob = nnet_out;
-    } else {
-        logprob = std::log(nnet_out + std::numeric_limits<BaseFloat>::epsilon());
-    }
-    CHECK(!std::isnan(logprob) && !std::isinf(logprob));
+    CHECK_EQ(frame_idx, 0);
+    logprob = framelikelihood_[TokenId2NnetId(index)];
 
     return acoustic_scale_ * logprob;
 }
 
 void Decodable::Reset() {
-    if (frontend_ != nullptr) frontend_->Reset();
-    if (nnet_ != nullptr) nnet_->Reset();
+    if (nnet_producer_ != nullptr) nnet_producer_->Reset();
     frame_offset_ = 0;
     frames_ready_ = 0;
-    nnet_out_cache_.Resize(0, 0);
+    framelikelihood_.clear();
 }
 
 void Decodable::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
                                    float reverse_weight,
                                    std::vector<float>* rescoring_score) {
     kaldi::Timer timer;
-    nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
+    nnet_producer_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
     VLOG(1) << "Attention Rescoring cost: " << timer.Elapsed() << " sec.";
 }
-}  // namespace ppspeech
\ No newline at end of file
+}  // namespace ppspeech
diff --git a/speechx/speechx/nnet/decodable.h b/runtime/engine/asr/nnet/decodable.h
similarity index 81%
rename from speechx/speechx/nnet/decodable.h
rename to runtime/engine/asr/nnet/decodable.h
index dd7b329e..f6448670 100644
--- a/speechx/speechx/nnet/decodable.h
+++ b/runtime/engine/asr/nnet/decodable.h
@@ -12,11 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#pragma once + #include "base/common.h" -#include "frontend/audio/frontend_itf.h" #include "kaldi/decoder/decodable-itf.h" -#include "kaldi/matrix/kaldi-matrix.h" +#include "matrix/kaldi-matrix.h" #include "nnet/nnet_itf.h" +#include "nnet/nnet_producer.h" namespace ppspeech { @@ -24,12 +26,9 @@ struct DecodableOpts; class Decodable : public kaldi::DecodableInterface { public: - explicit Decodable(const std::shared_ptr& nnet, - const std::shared_ptr& frontend, + explicit Decodable(const std::shared_ptr& nnet_producer, kaldi::BaseFloat acoustic_scale = 1.0); - // void Init(DecodableOpts config); - // nnet logprob output, used by wfst virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index); @@ -57,23 +56,17 @@ class Decodable : public kaldi::DecodableInterface { void Reset(); - bool IsInputFinished() const { return frontend_->IsFinished(); } + bool IsInputFinished() const { return nnet_producer_->IsFinished(); } bool EnsureFrameHaveComputed(int32 frame); int32 TokenId2NnetId(int32 token_id); - std::shared_ptr Nnet() { return nnet_; } - // for offline test void Acceptlikelihood(const kaldi::Matrix& likelihood); private: - std::shared_ptr frontend_; - std::shared_ptr nnet_; - - // nnet outputs' cache - kaldi::Matrix nnet_out_cache_; + std::shared_ptr nnet_producer_; // the frame is nnet prob frame rather than audio feature frame // nnet frame subsample the feature frame @@ -85,6 +78,7 @@ class Decodable : public kaldi::DecodableInterface { // so use subsampled_frame int32 current_log_post_subsampled_offset_; int32 num_chunk_computed_; + std::vector framelikelihood_; kaldi::BaseFloat acoustic_scale_; }; diff --git a/speechx/speechx/nnet/nnet_itf.h b/runtime/engine/asr/nnet/nnet_itf.h similarity index 70% rename from speechx/speechx/nnet/nnet_itf.h rename to runtime/engine/asr/nnet/nnet_itf.h index a504cce5..ac105d11 100644 --- a/speechx/speechx/nnet/nnet_itf.h +++ b/runtime/engine/asr/nnet/nnet_itf.h @@ -15,7 +15,6 @@ #include "base/basic_types.h" #include "kaldi/base/kaldi-types.h" -#include "kaldi/matrix/kaldi-matrix.h" #include "kaldi/util/options-itf.h" DECLARE_int32(subsampling_rate); @@ -25,26 +24,20 @@ DECLARE_string(model_input_names); DECLARE_string(model_output_names); DECLARE_string(model_cache_names); DECLARE_string(model_cache_shapes); +#ifdef USE_ONNX +DECLARE_bool(with_onnx_model); +#endif namespace ppspeech { struct ModelOptions { // common int subsample_rate{1}; - int thread_num{1}; // predictor thread pool size for ds2; bool use_gpu{false}; std::string model_path; - - std::string param_path; - - // ds2 for inference - std::string input_names{}; - std::string output_names{}; - std::string cache_names{}; - std::string cache_shape{}; - bool switch_ir_optim{false}; - bool enable_fc_padding{false}; - bool enable_profile{false}; +#ifdef USE_ONNX + bool with_onnx_model{false}; +#endif static ModelOptions InitFromFlags() { ModelOptions opts; @@ -52,26 +45,17 @@ struct ModelOptions { LOG(INFO) << "subsampling rate: " << opts.subsample_rate; opts.model_path = FLAGS_model_path; LOG(INFO) << "model path: " << opts.model_path; - - opts.param_path = FLAGS_param_path; - LOG(INFO) << "param path: " << opts.param_path; - - LOG(INFO) << "DS2 param: "; - opts.cache_names = FLAGS_model_cache_names; - LOG(INFO) << " cache names: " << opts.cache_names; - opts.cache_shape = FLAGS_model_cache_shapes; - LOG(INFO) << " cache shape: " << opts.cache_shape; - opts.input_names = FLAGS_model_input_names; - LOG(INFO) << " input names: " << opts.input_names; - opts.output_names = 
FLAGS_model_output_names; - LOG(INFO) << " output names: " << opts.output_names; +#ifdef USE_ONNX + opts.with_onnx_model = FLAGS_with_onnx_model; + LOG(INFO) << "with onnx model: " << opts.with_onnx_model; +#endif return opts; } }; struct NnetOut { // nnet out. maybe logprob or prob. Almost time this is logprob. - kaldi::Vector logprobs; + std::vector logprobs; int32 vocab_dim; // nnet state. Only using in Attention model. @@ -89,7 +73,7 @@ class NnetInterface { // nnet do not cache feats, feats cached by frontend. // nnet cache model state, i.e. encoder_outs, att_cache, cnn_cache, // frame_offset. - virtual void FeedForward(const kaldi::Vector& features, + virtual void FeedForward(const std::vector& features, const int32& feature_dim, NnetOut* out) = 0; @@ -105,14 +89,14 @@ class NnetInterface { // using to get encoder outs. e.g. seq2seq with Attention model. virtual void EncoderOuts( - std::vector>* encoder_out) const = 0; + std::vector>* encoder_out) const = 0; }; class NnetBase : public NnetInterface { public: int SubsamplingRate() const { return subsampling_rate_; } - + virtual std::shared_ptr Clone() const = 0; protected: int subsampling_rate_{1}; }; diff --git a/runtime/engine/asr/nnet/nnet_producer.cc b/runtime/engine/asr/nnet/nnet_producer.cc new file mode 100644 index 00000000..529fae65 --- /dev/null +++ b/runtime/engine/asr/nnet/nnet_producer.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "nnet/nnet_producer.h" + +#include "matrix/kaldi-matrix.h" + +namespace ppspeech { + +using kaldi::BaseFloat; +using std::vector; + +NnetProducer::NnetProducer(std::shared_ptr nnet, + std::shared_ptr frontend, + float blank_threshold) + : nnet_(nnet), frontend_(frontend), blank_threshold_(blank_threshold) { + Reset(); +} + +void NnetProducer::Accept(const std::vector& inputs) { + frontend_->Accept(inputs); +} + +void NnetProducer::Acceptlikelihood( + const kaldi::Matrix& likelihood) { + std::vector prob; + prob.resize(likelihood.NumCols()); + for (size_t idx = 0; idx < likelihood.NumRows(); ++idx) { + for (size_t col = 0; col < likelihood.NumCols(); ++col) { + prob[col] = likelihood(idx, col); + } + cache_.push_back(prob); + } +} + +bool NnetProducer::Read(std::vector* nnet_prob) { + bool flag = cache_.pop(nnet_prob); + return flag; +} + +bool NnetProducer::Compute() { + vector features; + if (frontend_ == NULL || frontend_->Read(&features) == false) { + // no feat or frontend_ not init. 
+        if (frontend_->IsFinished() == true) {
+            finished_ = true;
+        }
+        return false;
+    }
+    CHECK_GE(frontend_->Dim(), 0);
+    VLOG(1) << "Forward in " << features.size() / frontend_->Dim() << " feats.";
+
+    NnetOut out;
+    nnet_->FeedForward(features, frontend_->Dim(), &out);
+    int32& vocab_dim = out.vocab_dim;
+    size_t nframes = out.logprobs.size() / vocab_dim;
+    VLOG(1) << "Forward out " << nframes << " decoder frames.";
+    for (size_t idx = 0; idx < nframes; ++idx) {
+        std::vector<BaseFloat> logprob(
+            out.logprobs.data() + idx * vocab_dim,
+            out.logprobs.data() + (idx + 1) * vocab_dim);
+        // process blank prob
+        float blank_prob = std::exp(logprob[0]);
+        if (blank_prob > blank_threshold_) {
+            last_frame_logprob_ = logprob;
+            is_last_frame_skip_ = true;
+            continue;
+        } else {
+            int cur_max = std::max_element(logprob.begin(), logprob.end()) -
+                          logprob.begin();
+            if (cur_max == last_max_elem_ && cur_max != 0 &&
+                is_last_frame_skip_) {
+                cache_.push_back(last_frame_logprob_);
+                last_max_elem_ = cur_max;
+            }
+            last_max_elem_ = cur_max;
+            is_last_frame_skip_ = false;
+            cache_.push_back(logprob);
+        }
+    }
+    return true;
+}
+
+void NnetProducer::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
+                                      float reverse_weight,
+                                      std::vector<float>* rescoring_score) {
+    nnet_->AttentionRescoring(hyps, reverse_weight, rescoring_score);
+}
+
+}  // namespace ppspeech
diff --git a/runtime/engine/asr/nnet/nnet_producer.h b/runtime/engine/asr/nnet/nnet_producer.h
new file mode 100644
index 00000000..21aee067
--- /dev/null
+++ b/runtime/engine/asr/nnet/nnet_producer.h
@@ -0,0 +1,77 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
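One detail of `NnetProducer::Compute()` above is worth a worked example: `out.logprobs` holds log-posteriors, so `std::exp(logprob[0])` recovers the blank probability, and a frame is skipped when it exceeds `blank_threshold_`. Note the member defaults to 0.0 in the header below, so a useful operating point has to come in through the constructor. A toy sketch of the filter, with 0.98 as an assumed threshold:

```cpp
#include <cmath>
#include <vector>

// Toy illustration, not part of the patch. Index 0 is assumed to be the CTC
// blank token, matching the patch's use of logprob[0].
bool IsBlankFrame(const std::vector<float>& logprob, float blank_threshold) {
    return std::exp(logprob[0]) > blank_threshold;
}

// With blank_threshold = 0.98:
//   logprob[0] = -0.010 -> exp ~= 0.990 -> frame skipped
//   logprob[0] = -0.105 -> exp ~= 0.900 -> frame kept
```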
+ +#pragma once + +#include "base/common.h" +#include "base/safe_queue.h" +#include "frontend/frontend_itf.h" +#include "nnet/nnet_itf.h" + +namespace ppspeech { + +class NnetProducer { + public: + explicit NnetProducer(std::shared_ptr nnet, + std::shared_ptr frontend, + float blank_threshold); + // Feed feats or waves + void Accept(const std::vector& inputs); + + void Acceptlikelihood(const kaldi::Matrix& likelihood); + + // nnet + bool Read(std::vector* nnet_prob); + + bool Empty() const { return cache_.empty(); } + + void SetInputFinished() { + LOG(INFO) << "set finished"; + frontend_->SetFinished(); + } + + // the compute thread exit + bool IsFinished() const { + return (frontend_->IsFinished() && finished_); + } + + ~NnetProducer() {} + + void Reset() { + if (frontend_ != NULL) frontend_->Reset(); + if (nnet_ != NULL) nnet_->Reset(); + cache_.clear(); + finished_ = false; + } + + void AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score); + + bool Compute(); + private: + + std::shared_ptr frontend_; + std::shared_ptr nnet_; + SafeQueue> cache_; + std::vector last_frame_logprob_; + bool is_last_frame_skip_ = false; + int last_max_elem_ = -1; + float blank_threshold_ = 0.0; + bool finished_; + + DISALLOW_COPY_AND_ASSIGN(NnetProducer); +}; + +} // namespace ppspeech diff --git a/speechx/speechx/nnet/u2_nnet.cc b/runtime/engine/asr/nnet/u2_nnet.cc similarity index 87% rename from speechx/speechx/nnet/u2_nnet.cc rename to runtime/engine/asr/nnet/u2_nnet.cc index 7707406a..9a09514e 100644 --- a/speechx/speechx/nnet/u2_nnet.cc +++ b/runtime/engine/asr/nnet/u2_nnet.cc @@ -17,12 +17,13 @@ // https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/asr_model.cc #include "nnet/u2_nnet.h" +#include -#ifdef USE_PROFILING +#ifdef WITH_PROFILING #include "paddle/fluid/platform/profiler.h" using paddle::platform::RecordEvent; using paddle::platform::TracerEventType; -#endif // end USE_PROFILING +#endif // end WITH_PROFILING namespace ppspeech { @@ -30,7 +31,7 @@ namespace ppspeech { void U2Nnet::LoadModel(const std::string& model_path_w_prefix) { paddle::jit::utils::InitKernelSignatureMap(); -#ifdef USE_GPU +#ifdef WITH_GPU dev_ = phi::GPUPlace(); #else dev_ = phi::CPUPlace(); @@ -62,12 +63,12 @@ void U2Nnet::LoadModel(const std::string& model_path_w_prefix) { } void U2Nnet::Warmup() { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("warmup", TracerEventType::UserDefined, 1); #endif { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event( "warmup-encoder-ctc", TracerEventType::UserDefined, 1); #endif @@ -91,7 +92,7 @@ void U2Nnet::Warmup() { } { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("warmup-decoder", TracerEventType::UserDefined, 1); #endif auto hyps = @@ -101,10 +102,10 @@ void U2Nnet::Warmup() { auto encoder_out = paddle::ones( {1, 20, 512}, paddle::DataType::FLOAT32, phi::CPUPlace()); - std::vector inputs{ + std::vector inputs{ hyps, hyps_lens, encoder_out}; - std::vector outputs = + std::vector outputs = forward_attention_decoder_(inputs); } @@ -118,27 +119,46 @@ U2Nnet::U2Nnet(const ModelOptions& opts) : opts_(opts) { // shallow copy U2Nnet::U2Nnet(const U2Nnet& other) { // copy meta - right_context_ = other.right_context_; - subsampling_rate_ = other.subsampling_rate_; - sos_ = other.sos_; - eos_ = other.eos_; - is_bidecoder_ = other.is_bidecoder_; chunk_size_ = other.chunk_size_; num_left_chunks_ = other.num_left_chunks_; - - forward_encoder_chunk_ = other.forward_encoder_chunk_; - 
forward_attention_decoder_ = other.forward_attention_decoder_; - ctc_activation_ = other.ctc_activation_; - offset_ = other.offset_; // copy model ptr - model_ = other.model_; + // model_ = other.model_->Clone(); + // hack, fix later + #ifdef WITH_GPU + dev_ = phi::GPUPlace(); + #else + dev_ = phi::CPUPlace(); + #endif + paddle::jit::Layer model = paddle::jit::Load(other.opts_.model_path, dev_); + model_ = std::make_shared(std::move(model)); + ctc_activation_ = model_->Function("ctc_activation"); + subsampling_rate_ = model_->Attribute("subsampling_rate"); + right_context_ = model_->Attribute("right_context"); + sos_ = model_->Attribute("sos_symbol"); + eos_ = model_->Attribute("eos_symbol"); + is_bidecoder_ = model_->Attribute("is_bidirectional_decoder"); + + forward_encoder_chunk_ = model_->Function("forward_encoder_chunk"); + forward_attention_decoder_ = model_->Function("forward_attention_decoder"); + ctc_activation_ = model_->Function("ctc_activation"); + CHECK(forward_encoder_chunk_.IsValid()); + CHECK(forward_attention_decoder_.IsValid()); + CHECK(ctc_activation_.IsValid()); + + LOG(INFO) << "Paddle Model Info: "; + LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_; + LOG(INFO) << "\tright context " << right_context_; + LOG(INFO) << "\tsos " << sos_; + LOG(INFO) << "\teos " << eos_; + LOG(INFO) << "\tis bidecoder " << is_bidecoder_ << std::endl; + // ignore inner states } -std::shared_ptr U2Nnet::Copy() const { +std::shared_ptr U2Nnet::Clone() const { auto asr_model = std::make_shared(*this); // reset inner state for new decoding asr_model->Reset(); @@ -154,6 +174,7 @@ void U2Nnet::Reset() { std::move(paddle::zeros({0, 0, 0, 0}, paddle::DataType::FLOAT32)); encoder_outs_.clear(); + VLOG(1) << "FeedForward cost: " << cost_time_ << " sec. "; VLOG(3) << "u2nnet reset"; } @@ -165,23 +186,18 @@ void U2Nnet::FeedEncoderOuts(const paddle::Tensor& encoder_out) { } -void U2Nnet::FeedForward(const kaldi::Vector& features, +void U2Nnet::FeedForward(const std::vector& features, const int32& feature_dim, NnetOut* out) { kaldi::Timer timer; - std::vector chunk_feats(features.Data(), - features.Data() + features.Dim()); std::vector ctc_probs; ForwardEncoderChunkImpl( - chunk_feats, feature_dim, &ctc_probs, &out->vocab_dim); - - out->logprobs.Resize(ctc_probs.size(), kaldi::kSetZero); - std::memcpy(out->logprobs.Data(), - ctc_probs.data(), - ctc_probs.size() * sizeof(kaldi::BaseFloat)); - VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. " - << chunk_feats.size() / feature_dim << " frames."; + features, feature_dim, &out->logprobs, &out->vocab_dim); + float forward_chunk_time = timer.Elapsed(); + VLOG(1) << "FeedForward cost: " << forward_chunk_time << " sec. 
" + << features.size() / feature_dim << " frames."; + cost_time_ += forward_chunk_time; } @@ -190,7 +206,7 @@ void U2Nnet::ForwardEncoderChunkImpl( const int32& feat_dim, std::vector* out_prob, int32* vocab_dim) { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event( "ForwardEncoderChunkImpl", TracerEventType::UserDefined, 1); #endif @@ -210,7 +226,7 @@ void U2Nnet::ForwardEncoderChunkImpl( // not cache feature in nnet CHECK_EQ(cached_feats_.size(), 0); - // CHECK_EQ(std::is_same::value, true); + CHECK_EQ((std::is_same::value), true); std::memcpy(feats_ptr, chunk_feats.data(), chunk_feats.size() * sizeof(kaldi::BaseFloat)); @@ -218,7 +234,7 @@ void U2Nnet::ForwardEncoderChunkImpl( VLOG(3) << "feats shape: " << feats.shape()[0] << ", " << feats.shape()[1] << ", " << feats.shape()[2]; -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("feat", std::ios_base::app | std::ios_base::out); path << offset_; @@ -237,7 +253,7 @@ void U2Nnet::ForwardEncoderChunkImpl( #endif // Endocer chunk forward -#ifdef USE_GPU +#ifdef WITH_GPU feats = feats.copy_to(paddle::GPUPlace(), /*blocking*/ false); att_cache_ = att_cache_.copy_to(paddle::GPUPlace()), /*blocking*/ false; cnn_cache_ = cnn_cache_.copy_to(Paddle::GPUPlace(), /*blocking*/ false); @@ -254,7 +270,7 @@ void U2Nnet::ForwardEncoderChunkImpl( std::vector outputs = forward_encoder_chunk_(inputs); CHECK_EQ(outputs.size(), 3); -#ifdef USE_GPU +#ifdef WITH_GPU paddle::Tensor chunk_out = outputs[0].copy_to(paddle::CPUPlace()); att_cache_ = outputs[1].copy_to(paddle::CPUPlace()); cnn_cache_ = outputs[2].copy_to(paddle::CPUPlace()); @@ -264,7 +280,7 @@ void U2Nnet::ForwardEncoderChunkImpl( cnn_cache_ = outputs[2]; #endif -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_logits", std::ios_base::app | std::ios_base::out); @@ -294,7 +310,7 @@ void U2Nnet::ForwardEncoderChunkImpl( encoder_outs_.push_back(chunk_out); VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size(); -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_logits_list", std::ios_base::app | std::ios_base::out); @@ -313,7 +329,7 @@ void U2Nnet::ForwardEncoderChunkImpl( } #endif // end TEST_DEBUG -#ifdef USE_GPU +#ifdef WITH_GPU #error "Not implementation." 
@@ -327,7 +343,7 @@ void U2Nnet::ForwardEncoderChunkImpl( CHECK_EQ(outputs.size(), 1); paddle::Tensor ctc_log_probs = outputs[0]; -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_logprob", std::ios_base::app | std::ios_base::out); @@ -349,7 +365,7 @@ void U2Nnet::ForwardEncoderChunkImpl( } #endif // end TEST_DEBUG -#endif // end USE_GPU +#endif // end WITH_GPU // Copy to output, (B=1,T,D) std::vector ctc_log_probs_shape = ctc_log_probs.shape(); @@ -366,7 +382,7 @@ void U2Nnet::ForwardEncoderChunkImpl( std::memcpy( out_prob->data(), ctc_log_probs_ptr, T * D * sizeof(kaldi::BaseFloat)); -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_logits_list_ctc", std::ios_base::app | std::ios_base::out); @@ -415,7 +431,7 @@ float U2Nnet::ComputePathScore(const paddle::Tensor& prob, void U2Nnet::AttentionRescoring(const std::vector>& hyps, float reverse_weight, std::vector* rescoring_score) { -#ifdef USE_PROFILING +#ifdef WITH_PROFILING RecordEvent event("AttentionRescoring", TracerEventType::UserDefined, 1); #endif CHECK(rescoring_score != nullptr); @@ -457,7 +473,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } } -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_logits_concat", std::ios_base::app | std::ios_base::out); @@ -481,7 +497,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, paddle::Tensor encoder_out = paddle::concat(encoder_outs_, 1); VLOG(2) << "encoder_outs_ size: " << encoder_outs_.size(); -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_out0", std::ios_base::app | std::ios_base::out); @@ -500,7 +516,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("encoder_out", std::ios_base::app | std::ios_base::out); @@ -519,7 +535,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG - std::vector inputs{ + std::vector inputs{ hyps_tensor, hyps_lens, encoder_out}; std::vector outputs = forward_attention_decoder_(inputs); CHECK_EQ(outputs.size(), 2); @@ -531,7 +547,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, CHECK_EQ(probs_shape[0], num_hyps); CHECK_EQ(probs_shape[1], max_hyps_len); -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("decoder_logprob", std::ios_base::app | std::ios_base::out); @@ -549,7 +565,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("hyps_lens", std::ios_base::app | std::ios_base::out); @@ -565,7 +581,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } #endif // end TEST_DEBUG -#ifdef TEST_DEBUG +#ifdef PPS_DEBUG { std::stringstream path("hyps_tensor", std::ios_base::app | std::ios_base::out); @@ -590,7 +606,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, } else { // dump r_probs CHECK_EQ(r_probs_shape.size(), 1); - CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0]; + //CHECK_EQ(r_probs_shape[0], 1) << r_probs_shape[0]; } // compute rescoring score @@ -600,15 +616,15 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, VLOG(2) << "split prob: " << probs_v.size() << " " << probs_v[0].shape().size() << " 0: " << probs_v[0].shape()[0] << ", " << probs_v[0].shape()[1] << ", " << probs_v[0].shape()[2]; - CHECK(static_cast(probs_v.size()) == num_hyps) - << ": is " << probs_v.size() << " expect: " << num_hyps; + //CHECK(static_cast(probs_v.size()) == num_hyps) + // << 
": is " << probs_v.size() << " expect: " << num_hyps; std::vector r_probs_v; if (is_bidecoder_ && reverse_weight > 0) { r_probs_v = paddle::experimental::split_with_num(r_probs, num_hyps, 0); - CHECK(static_cast(r_probs_v.size()) == num_hyps) - << "r_probs_v size: is " << r_probs_v.size() - << " expect: " << num_hyps; + //CHECK(static_cast(r_probs_v.size()) == num_hyps) + // << "r_probs_v size: is " << r_probs_v.size() + // << " expect: " << num_hyps; } for (int i = 0; i < num_hyps; ++i) { @@ -638,7 +654,7 @@ void U2Nnet::AttentionRescoring(const std::vector>& hyps, void U2Nnet::EncoderOuts( - std::vector>* encoder_out) const { + std::vector>* encoder_out) const { // list of (B=1,T,D) int size = encoder_outs_.size(); VLOG(3) << "encoder_outs_ size: " << size; @@ -650,18 +666,18 @@ void U2Nnet::EncoderOuts( const int& B = shape[0]; const int& T = shape[1]; const int& D = shape[2]; - CHECK(B == 1) << "Only support batch one."; + //CHECK(B == 1) << "Only support batch one."; VLOG(3) << "encoder out " << i << " shape: (" << B << "," << T << "," << D << ")"; const float* this_tensor_ptr = item.data(); for (int j = 0; j < T; j++) { const float* cur = this_tensor_ptr + j * D; - kaldi::Vector out(D); - std::memcpy(out.Data(), cur, D * sizeof(kaldi::BaseFloat)); + std::vector out(D); + std::memcpy(out.data(), cur, D * sizeof(kaldi::BaseFloat)); encoder_out->emplace_back(out); } } } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/nnet/u2_nnet.h b/runtime/engine/asr/nnet/u2_nnet.h similarity index 91% rename from speechx/speechx/nnet/u2_nnet.h rename to runtime/engine/asr/nnet/u2_nnet.h index 23cc0ea3..dba5c55e 100644 --- a/speechx/speechx/nnet/u2_nnet.h +++ b/runtime/engine/asr/nnet/u2_nnet.h @@ -18,7 +18,7 @@ #pragma once #include "base/common.h" -#include "kaldi/matrix/kaldi-matrix.h" +#include "matrix/kaldi-matrix.h" #include "nnet/nnet_itf.h" #include "paddle/extension.h" #include "paddle/jit/all.h" @@ -42,7 +42,7 @@ class U2NnetBase : public NnetBase { num_left_chunks_ = num_left_chunks; } - virtual std::shared_ptr Copy() const = 0; + virtual std::shared_ptr Clone() const = 0; protected: virtual void ForwardEncoderChunkImpl( @@ -76,7 +76,7 @@ class U2Nnet : public U2NnetBase { explicit U2Nnet(const ModelOptions& opts); U2Nnet(const U2Nnet& other); - void FeedForward(const kaldi::Vector& features, + void FeedForward(const std::vector& features, const int32& feature_dim, NnetOut* out) override; @@ -91,7 +91,7 @@ class U2Nnet : public U2NnetBase { std::shared_ptr model() const { return model_; } - std::shared_ptr Copy() const override; + std::shared_ptr Clone() const override; void ForwardEncoderChunkImpl( const std::vector& chunk_feats, @@ -111,10 +111,10 @@ class U2Nnet : public U2NnetBase { void FeedEncoderOuts(const paddle::Tensor& encoder_out); void EncoderOuts( - std::vector>* encoder_out) const; + std::vector>* encoder_out) const; + ModelOptions opts_; // hack, fix later private: - ModelOptions opts_; phi::Place dev_; std::shared_ptr model_{nullptr}; @@ -127,6 +127,7 @@ class U2Nnet : public U2NnetBase { paddle::jit::Function forward_encoder_chunk_; paddle::jit::Function forward_attention_decoder_; paddle::jit::Function ctc_activation_; + float cost_time_ = 0.0; }; } // namespace ppspeech \ No newline at end of file diff --git a/speechx/speechx/nnet/u2_nnet_main.cc b/runtime/engine/asr/nnet/u2_nnet_main.cc similarity index 99% rename from speechx/speechx/nnet/u2_nnet_main.cc rename to runtime/engine/asr/nnet/u2_nnet_main.cc index 
53fc5554..e60ae7e8 100644
--- a/speechx/speechx/nnet/u2_nnet_main.cc
+++ b/runtime/engine/asr/nnet/u2_nnet_main.cc
@@ -15,8 +15,8 @@
 #include "base/common.h"
 #include "decoder/param.h"
-#include "frontend/audio/assembler.h"
-#include "frontend/audio/data_cache.h"
+#include "frontend/assembler.h"
+#include "frontend/data_cache.h"
 #include "kaldi/util/table-types.h"
 #include "nnet/decodable.h"
 #include "nnet/u2_nnet.h"
diff --git a/runtime/engine/asr/nnet/u2_nnet_thread_main.cc b/runtime/engine/asr/nnet/u2_nnet_thread_main.cc
new file mode 100644
index 00000000..008dbb1e
--- /dev/null
+++ b/runtime/engine/asr/nnet/u2_nnet_thread_main.cc
@@ -0,0 +1,145 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef USE_ONNX
+    #include "nnet/u2_nnet.h"
+#else
+    #include "nnet/u2_onnx_nnet.h"
+#endif
+#include "base/common.h"
+#include "decoder/param.h"
+#include "frontend/feature_pipeline.h"
+#include "frontend/wave-reader.h"
+#include "kaldi/util/table-types.h"
+#include "nnet/decodable.h"
+#include "nnet/nnet_producer.h"
+
+DEFINE_string(wav_rspecifier, "", "test wav rspecifier");
+DEFINE_string(nnet_prob_wspecifier, "", "nnet prob wspecifier");
+DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
+DEFINE_int32(sample_rate, 16000, "sample rate");
+
+using kaldi::BaseFloat;
+using kaldi::Matrix;
+using std::vector;
+
+int main(int argc, char* argv[]) {
+    gflags::SetUsageMessage("Usage:");
+    gflags::ParseCommandLineFlags(&argc, &argv, false);
+    google::InitGoogleLogging(argv[0]);
+    google::InstallFailureSignalHandler();
+    FLAGS_logtostderr = 1;
+
+    int32 num_done = 0, num_err = 0;
+    int sample_rate = FLAGS_sample_rate;
+    float streaming_chunk = FLAGS_streaming_chunk;
+    int chunk_sample_size = streaming_chunk * sample_rate;
+
+    CHECK_GT(FLAGS_wav_rspecifier.size(), 0);
+    CHECK_GT(FLAGS_nnet_prob_wspecifier.size(), 0);
+    CHECK_GT(FLAGS_model_path.size(), 0);
+    LOG(INFO) << "input rspecifier: " << FLAGS_wav_rspecifier;
+    LOG(INFO) << "output wspecifier: " << FLAGS_nnet_prob_wspecifier;
+    LOG(INFO) << "model path: " << FLAGS_model_path;
+
+    kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
+        FLAGS_wav_rspecifier);
+    kaldi::BaseFloatMatrixWriter nnet_out_writer(FLAGS_nnet_prob_wspecifier);
+
+    ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();
+    ppspeech::FeaturePipelineOptions feature_opts =
+        ppspeech::FeaturePipelineOptions::InitFromFlags();
+    feature_opts.assembler_opts.fill_zero = false;
+
+#ifndef USE_ONNX
+    std::shared_ptr<ppspeech::NnetBase> nnet(new ppspeech::U2Nnet(model_opts));
+#else
+    std::shared_ptr<ppspeech::NnetBase> nnet(new ppspeech::U2OnnxNnet(model_opts));
+#endif
+    std::shared_ptr<ppspeech::FeaturePipeline> feature_pipeline(
+        new ppspeech::FeaturePipeline(feature_opts));
+    std::shared_ptr<ppspeech::NnetProducer> nnet_producer(
+        new ppspeech::NnetProducer(nnet, feature_pipeline));
+    kaldi::Timer timer;
+    float tot_wav_duration = 0;
+
+    for (; !wav_reader.Done(); wav_reader.Next()) {
+        std::string utt = wav_reader.Key();
+        const
kaldi::WaveData& wave_data = wav_reader.Value();
+        LOG(INFO) << "utt: " << utt;
+        LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
+        double dur = wave_data.Duration();
+        tot_wav_duration += dur;
+
+        int32 this_channel = 0;
+        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
+                                                    this_channel);
+        int tot_samples = waveform.Dim();
+        LOG(INFO) << "wav len (sample): " << tot_samples;
+
+        int sample_offset = 0;
+        kaldi::Timer timer;
+
+        while (sample_offset < tot_samples) {
+            int cur_chunk_size =
+                std::min(chunk_sample_size, tot_samples - sample_offset);
+
+            std::vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
+            for (int i = 0; i < cur_chunk_size; ++i) {
+                wav_chunk[i] = waveform(sample_offset + i);
+            }
+
+            nnet_producer->Accept(wav_chunk);
+            if (cur_chunk_size < chunk_sample_size) {
+                nnet_producer->SetInputFinished();
+            }
+
+            // no overlap
+            sample_offset += cur_chunk_size;
+        }
+        CHECK(sample_offset == tot_samples);
+
+        std::vector<std::vector<kaldi::BaseFloat>> prob_vec;
+        while (1) {
+            std::vector<kaldi::BaseFloat> logprobs;
+            bool isok = nnet_producer->Read(&logprobs);
+            if (nnet_producer->IsFinished()) break;
+            if (isok == false) continue;
+            prob_vec.push_back(logprobs);
+        }
+        {
+            // write nnet output
+            kaldi::MatrixIndexT nrow = prob_vec.size();
+            kaldi::MatrixIndexT ncol = prob_vec[0].size();
+            LOG(INFO) << "nnet out shape: " << nrow << ", " << ncol;
+            kaldi::Matrix<kaldi::BaseFloat> nnet_out(nrow, ncol);
+            for (int32 row_idx = 0; row_idx < nrow; ++row_idx) {
+                for (int32 col_idx = 0; col_idx < ncol; ++col_idx) {
+                    nnet_out(row_idx, col_idx) = prob_vec[row_idx][col_idx];
+                }
+            }
+            nnet_out_writer.Write(utt, nnet_out);
+        }
+        ++num_done;
+        nnet_producer->Reset();
+    }
+
+    nnet_producer->Wait();
+    double elapsed = timer.Elapsed();
+    LOG(INFO) << "Program cost: " << elapsed << " sec";
+
+    LOG(INFO) << "Done " << num_done << " utterances, " << num_err
+              << " with errors.";
+    return (num_done != 0 ? 0 : 1);
+}
diff --git a/runtime/engine/asr/nnet/u2_onnx_nnet.cc b/runtime/engine/asr/nnet/u2_onnx_nnet.cc
new file mode 100644
index 00000000..d5e2fdb6
--- /dev/null
+++ b/runtime/engine/asr/nnet/u2_onnx_nnet.cc
@@ -0,0 +1,414 @@
+// Copyright 2022 Horizon Robotics. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// modified from
+// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.cc
+
+#include "nnet/u2_onnx_nnet.h"
+#include "common/base/config.h"
+
+namespace ppspeech {
+
+void U2OnnxNnet::LoadModel(const std::string& model_dir) {
+    std::string encoder_onnx_path = model_dir + "/encoder.onnx";
+    std::string rescore_onnx_path = model_dir + "/decoder.onnx";
+    std::string ctc_onnx_path = model_dir + "/ctc.onnx";
+    std::string param_path = model_dir + "/param.onnx";
+    // 1.
Load sessions + try { + encoder_ = std::make_shared(); + ctc_ = std::make_shared(); + rescore_ = std::make_shared(); + fastdeploy::RuntimeOption runtime_option; + runtime_option.UseOrtBackend(); + runtime_option.UseCpu(); + runtime_option.SetCpuThreadNum(1); + runtime_option.SetModelPath(encoder_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX); + assert(encoder_->Init(runtime_option)); + runtime_option.SetModelPath(rescore_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX); + assert(rescore_->Init(runtime_option)); + runtime_option.SetModelPath(ctc_onnx_path.c_str(), "", fastdeploy::ModelFormat::ONNX); + assert(ctc_->Init(runtime_option)); + } catch (std::exception const& e) { + LOG(ERROR) << "error when load onnx model: " << e.what(); + exit(0); + } + + Config conf(param_path); + encoder_output_size_ = conf.Read("output_size", encoder_output_size_); + num_blocks_ = conf.Read("num_blocks", num_blocks_); + head_ = conf.Read("head", head_); + cnn_module_kernel_ = conf.Read("cnn_module_kernel", cnn_module_kernel_); + subsampling_rate_ = conf.Read("subsampling_rate", subsampling_rate_); + right_context_ = conf.Read("right_context", right_context_); + sos_= conf.Read("sos_symbol", sos_); + eos_= conf.Read("eos_symbol", eos_); + is_bidecoder_= conf.Read("is_bidirectional_decoder", is_bidecoder_); + chunk_size_= conf.Read("chunk_size", chunk_size_); + num_left_chunks_ = conf.Read("left_chunks", num_left_chunks_); + + LOG(INFO) << "Onnx Model Info:"; + LOG(INFO) << "\tencoder_output_size " << encoder_output_size_; + LOG(INFO) << "\tnum_blocks " << num_blocks_; + LOG(INFO) << "\thead " << head_; + LOG(INFO) << "\tcnn_module_kernel " << cnn_module_kernel_; + LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_; + LOG(INFO) << "\tright_context " << right_context_; + LOG(INFO) << "\tsos " << sos_; + LOG(INFO) << "\teos " << eos_; + LOG(INFO) << "\tis bidirectional decoder " << is_bidecoder_; + LOG(INFO) << "\tchunk_size " << chunk_size_; + LOG(INFO) << "\tnum_left_chunks " << num_left_chunks_; + + // 3. 
Read model nodes + LOG(INFO) << "Onnx Encoder:"; + GetInputOutputInfo(encoder_, &encoder_in_names_, &encoder_out_names_); + LOG(INFO) << "Onnx CTC:"; + GetInputOutputInfo(ctc_, &ctc_in_names_, &ctc_out_names_); + LOG(INFO) << "Onnx Rescore:"; + GetInputOutputInfo(rescore_, &rescore_in_names_, &rescore_out_names_); +} + +U2OnnxNnet::U2OnnxNnet(const ModelOptions& opts) : opts_(opts) { + LoadModel(opts_.model_path); +} + +// shallow copy +U2OnnxNnet::U2OnnxNnet(const U2OnnxNnet& other) { + // metadatas + encoder_output_size_ = other.encoder_output_size_; + num_blocks_ = other.num_blocks_; + head_ = other.head_; + cnn_module_kernel_ = other.cnn_module_kernel_; + right_context_ = other.right_context_; + subsampling_rate_ = other.subsampling_rate_; + sos_ = other.sos_; + eos_ = other.eos_; + is_bidecoder_ = other.is_bidecoder_; + chunk_size_ = other.chunk_size_; + num_left_chunks_ = other.num_left_chunks_; + offset_ = other.offset_; + + // session + encoder_ = other.encoder_; + ctc_ = other.ctc_; + rescore_ = other.rescore_; + + // node names + encoder_in_names_ = other.encoder_in_names_; + encoder_out_names_ = other.encoder_out_names_; + ctc_in_names_ = other.ctc_in_names_; + ctc_out_names_ = other.ctc_out_names_; + rescore_in_names_ = other.rescore_in_names_; + rescore_out_names_ = other.rescore_out_names_; +} + +void U2OnnxNnet::GetInputOutputInfo(const std::shared_ptr& runtime, + std::vector* in_names, std::vector* out_names) { + std::vector inputs_info = runtime->GetInputInfos(); + (*in_names).resize(inputs_info.size()); + for (int i = 0; i < inputs_info.size(); ++i){ + fastdeploy::TensorInfo info = inputs_info[i]; + + std::stringstream shape; + for(int j = 0; j < info.shape.size(); ++j){ + shape << info.shape[j]; + shape << " "; + } + LOG(INFO) << "\tInput " << i << " : name=" << info.name << " type=" << info.dtype + << " dims=" << shape.str(); + (*in_names)[i] = info.name; + } + std::vector outputs_info = runtime->GetOutputInfos(); + (*out_names).resize(outputs_info.size()); + for (int i = 0; i < outputs_info.size(); ++i){ + fastdeploy::TensorInfo info = outputs_info[i]; + + std::stringstream shape; + for(int j = 0; j < info.shape.size(); ++j){ + shape << info.shape[j]; + shape << " "; + } + LOG(INFO) << "\tOutput " << i << " : name=" << info.name << " type=" << info.dtype + << " dims=" << shape.str(); + (*out_names)[i] = info.name; + } +} + +std::shared_ptr U2OnnxNnet::Clone() const { + auto asr_model = std::make_shared(*this); + // reset inner state for new decoding + asr_model->Reset(); + return asr_model; +} + +void U2OnnxNnet::Reset() { + offset_ = 0; + encoder_outs_.clear(); + cached_feats_.clear(); + // Reset att_cache + if (num_left_chunks_ > 0) { + int required_cache_size = chunk_size_ * num_left_chunks_; + offset_ = required_cache_size; + att_cache_.resize(num_blocks_ * head_ * required_cache_size * + encoder_output_size_ / head_ * 2, + 0.0); + const std::vector att_cache_shape = {num_blocks_, head_, required_cache_size, + encoder_output_size_ / head_ * 2}; + att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data()); + } else { + att_cache_.resize(0, 0.0); + const std::vector att_cache_shape = {num_blocks_, head_, 0, + encoder_output_size_ / head_ * 2}; + att_cache_ort_.SetExternalData(att_cache_shape, fastdeploy::FDDataType::FP32, att_cache_.data()); + } + + // Reset cnn_cache + cnn_cache_.resize( + num_blocks_ * encoder_output_size_ * (cnn_module_kernel_ - 1), 0.0); + const std::vector cnn_cache_shape = {num_blocks_, 1, 
encoder_output_size_, + cnn_module_kernel_ - 1}; + cnn_cache_ort_.SetExternalData(cnn_cache_shape, fastdeploy::FDDataType::FP32, cnn_cache_.data()); +} + +void U2OnnxNnet::FeedForward(const std::vector& features, + const int32& feature_dim, + NnetOut* out) { + kaldi::Timer timer; + + std::vector ctc_probs; + ForwardEncoderChunkImpl( + features, feature_dim, &out->logprobs, &out->vocab_dim); + VLOG(1) << "FeedForward cost: " << timer.Elapsed() << " sec. " + << features.size() / feature_dim << " frames."; +} + +void U2OnnxNnet::ForwardEncoderChunkImpl( + const std::vector& chunk_feats, + const int32& feat_dim, + std::vector* out_prob, + int32* vocab_dim) { + + // 1. Prepare onnx required data, splice cached_feature_ and chunk_feats + // chunk + int num_frames = chunk_feats.size() / feat_dim; + VLOG(3) << "num_frames: " << num_frames; + VLOG(3) << "feat_dim: " << feat_dim; + const int feature_dim = feat_dim; + std::vector feats; + feats.insert(feats.end(), chunk_feats.begin(), chunk_feats.end()); + fastdeploy::FDTensor feats_ort; + const std::vector feats_shape = {1, num_frames, feature_dim}; + feats_ort.SetExternalData(feats_shape, fastdeploy::FDDataType::FP32, feats.data()); + + // offset + int64_t offset_int64 = static_cast(offset_); + fastdeploy::FDTensor offset_ort; + offset_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &offset_int64); + + // required_cache_size + int64_t required_cache_size = chunk_size_ * num_left_chunks_; + fastdeploy::FDTensor required_cache_size_ort(""); + required_cache_size_ort.SetExternalData({}, fastdeploy::FDDataType::INT64, &required_cache_size); + + // att_mask + fastdeploy::FDTensor att_mask_ort; + std::vector att_mask(required_cache_size + chunk_size_, 1); + if (num_left_chunks_ > 0) { + int chunk_idx = offset_ / chunk_size_ - num_left_chunks_; + if (chunk_idx < num_left_chunks_) { + for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) { + att_mask[i] = 0; + } + } + const std::vector att_mask_shape = {1, 1, required_cache_size + chunk_size_}; + att_mask_ort.SetExternalData(att_mask_shape, fastdeploy::FDDataType::BOOL, reinterpret_cast(att_mask.data())); + } + + // 2. 
Encoder chunk forward + std::vector inputs(encoder_in_names_.size()); + for (int i = 0; i < encoder_in_names_.size(); ++i) { + std::string name = encoder_in_names_[i]; + if (!strcmp(name.data(), "chunk")) { + inputs[i] = std::move(feats_ort); + inputs[i].name = "chunk"; + } else if (!strcmp(name.data(), "offset")) { + inputs[i] = std::move(offset_ort); + inputs[i].name = "offset"; + } else if (!strcmp(name.data(), "required_cache_size")) { + inputs[i] = std::move(required_cache_size_ort); + inputs[i].name = "required_cache_size"; + } else if (!strcmp(name.data(), "att_cache")) { + inputs[i] = std::move(att_cache_ort_); + inputs[i].name = "att_cache"; + } else if (!strcmp(name.data(), "cnn_cache")) { + inputs[i] = std::move(cnn_cache_ort_); + inputs[i].name = "cnn_cache"; + } else if (!strcmp(name.data(), "att_mask")) { + inputs[i] = std::move(att_mask_ort); + inputs[i].name = "att_mask"; + } + } + + std::vector ort_outputs; + assert(encoder_->Infer(inputs, &ort_outputs)); + + offset_ += static_cast(ort_outputs[0].shape[1]); + att_cache_ort_ = std::move(ort_outputs[1]); + cnn_cache_ort_ = std::move(ort_outputs[2]); + + std::vector ctc_inputs; + ctc_inputs.emplace_back(std::move(ort_outputs[0])); + // ctc_inputs[0] = std::move(ort_outputs[0]); + ctc_inputs[0].name = ctc_in_names_[0]; + + std::vector ctc_ort_outputs; + assert(ctc_->Infer(ctc_inputs, &ctc_ort_outputs)); + encoder_outs_.emplace_back(std::move(ctc_inputs[0])); // ***** + + float* logp_data = reinterpret_cast(ctc_ort_outputs[0].Data()); + + // Copy to output, (B=1,T,D) + std::vector ctc_log_probs_shape = ctc_ort_outputs[0].shape; + CHECK_EQ(ctc_log_probs_shape.size(), 3); + int B = ctc_log_probs_shape[0]; + CHECK_EQ(B, 1); + int T = ctc_log_probs_shape[1]; + int D = ctc_log_probs_shape[2]; + *vocab_dim = D; + + out_prob->resize(T * D); + std::memcpy( + out_prob->data(), logp_data, T * D * sizeof(kaldi::BaseFloat)); + return; +} + +float U2OnnxNnet::ComputeAttentionScore(const float* prob, + const std::vector& hyp, int eos, + int decode_out_len) { + float score = 0.0f; + for (size_t j = 0; j < hyp.size(); ++j) { + score += *(prob + j * decode_out_len + hyp[j]); + } + score += *(prob + hyp.size() * decode_out_len + eos); + return score; +} + +void U2OnnxNnet::AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score) { + CHECK(rescoring_score != nullptr); + int num_hyps = hyps.size(); + rescoring_score->resize(num_hyps, 0.0f); + + if (num_hyps == 0) { + return; + } + // No encoder output + if (encoder_outs_.size() == 0) { + return; + } + + std::vector hyps_lens; + int max_hyps_len = 0; + for (size_t i = 0; i < num_hyps; ++i) { + int length = hyps[i].size() + 1; + max_hyps_len = std::max(length, max_hyps_len); + hyps_lens.emplace_back(static_cast(length)); + } + + std::vector rescore_input; + int encoder_len = 0; + for (int i = 0; i < encoder_outs_.size(); i++) { + float* encoder_outs_data = reinterpret_cast(encoder_outs_[i].Data()); + for (int j = 0; j < encoder_outs_[i].Numel(); j++) { + rescore_input.emplace_back(encoder_outs_data[j]); + } + encoder_len += encoder_outs_[i].shape[1]; + } + + std::vector hyps_pad; + + for (size_t i = 0; i < num_hyps; ++i) { + const std::vector& hyp = hyps[i]; + hyps_pad.emplace_back(sos_); + size_t j = 0; + for (; j < hyp.size(); ++j) { + hyps_pad.emplace_back(hyp[j]); + } + if (j == max_hyps_len - 1) { + continue; + } + for (; j < max_hyps_len - 1; ++j) { + hyps_pad.emplace_back(0); + } + } + + const std::vector hyps_pad_shape = {num_hyps, max_hyps_len}; + 
const std::vector hyps_lens_shape = {num_hyps};
+    const std::vector decode_input_shape = {1, encoder_len, encoder_output_size_};
+
+    fastdeploy::FDTensor hyps_pad_tensor_;
+    hyps_pad_tensor_.SetExternalData(hyps_pad_shape, fastdeploy::FDDataType::INT64, hyps_pad.data());
+    fastdeploy::FDTensor hyps_lens_tensor_;
+    hyps_lens_tensor_.SetExternalData(hyps_lens_shape, fastdeploy::FDDataType::INT64, hyps_lens.data());
+    fastdeploy::FDTensor decode_input_tensor_;
+    decode_input_tensor_.SetExternalData(decode_input_shape, fastdeploy::FDDataType::FP32, rescore_input.data());
+
+    std::vector rescore_inputs(3);
+
+    rescore_inputs[0] = std::move(hyps_pad_tensor_);
+    rescore_inputs[0].name = rescore_in_names_[0];
+    rescore_inputs[1] = std::move(hyps_lens_tensor_);
+    rescore_inputs[1].name = rescore_in_names_[1];
+    rescore_inputs[2] = std::move(decode_input_tensor_);
+    rescore_inputs[2].name = rescore_in_names_[2];
+
+    std::vector rescore_outputs;
+    assert(rescore_->Infer(rescore_inputs, &rescore_outputs));
+
+    float* decoder_outs_data = reinterpret_cast(rescore_outputs[0].Data());
+    float* r_decoder_outs_data = reinterpret_cast(rescore_outputs[1].Data());
+
+    int decode_out_len = rescore_outputs[0].shape[2];
+
+    for (size_t i = 0; i < num_hyps; ++i) {
+        const std::vector& hyp = hyps[i];
+        float score = 0.0f;
+        // left-to-right decoder score
+        score = ComputeAttentionScore(
+            decoder_outs_data + max_hyps_len * decode_out_len * i, hyp, eos_,
+            decode_out_len);
+        // Optional: used for right-to-left score
+        float r_score = 0.0f;
+        if (is_bidecoder_ && reverse_weight > 0) {
+            std::vector r_hyp(hyp.size());
+            std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
+            // right-to-left decoder score
+            r_score = ComputeAttentionScore(
+                r_decoder_outs_data + max_hyps_len * decode_out_len * i, r_hyp, eos_,
+                decode_out_len);
+        }
+        // combined left-to-right and right-to-left score
+        (*rescoring_score)[i] =
+            score * (1 - reverse_weight) + r_score * reverse_weight;
+    }
+}
+
+void U2OnnxNnet::EncoderOuts(
+    std::vector>* encoder_out) const {
+}
+
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/runtime/engine/asr/nnet/u2_onnx_nnet.h b/runtime/engine/asr/nnet/u2_onnx_nnet.h
new file mode 100644
index 00000000..6e9126b0
--- /dev/null
+++ b/runtime/engine/asr/nnet/u2_onnx_nnet.h
@@ -0,0 +1,97 @@
+// Copyright 2022 Horizon Robotics. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
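The rescoring path above mirrors the WeNet original it credits: every hypothesis gets a left-to-right decoder score, bidirectional decoders also score the reversed token sequence, and the two are mixed by `reverse_weight`. A toy illustration of the combination (the numbers are assumptions, not outputs of the patch): with `reverse_weight = 0.3`, a forward score of -4.0 and a reversed score of -5.0 combine to -4.0 * 0.7 + -5.0 * 0.3 = -4.3.

```cpp
#include <algorithm>
#include <vector>

// Illustrative helpers matching the math in U2OnnxNnet::AttentionRescoring();
// not part of the patch.
float CombineScores(float score, float r_score, float reverse_weight) {
    // score * (1 - reverse_weight) + r_score * reverse_weight
    return score * (1.0f - reverse_weight) + r_score * reverse_weight;
}

std::vector<int> ReverseHyp(const std::vector<int>& hyp) {
    // the right-to-left decoder scores the token sequence in reverse order
    std::vector<int> r_hyp(hyp.size());
    std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
    return r_hyp;
}
```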
+ +// modified from +// https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/onnx_asr_model.h + +#pragma once + +#include "base/common.h" +#include "matrix/kaldi-matrix.h" +#include "nnet/nnet_itf.h" +#include "nnet/u2_nnet.h" + +#include "fastdeploy/runtime.h" + +namespace ppspeech { + +class U2OnnxNnet : public U2NnetBase { + + public: + explicit U2OnnxNnet(const ModelOptions& opts); + U2OnnxNnet(const U2OnnxNnet& other); + + void FeedForward(const std::vector& features, + const int32& feature_dim, + NnetOut* out) override; + + void Reset() override; + + bool IsLogProb() override { return true; } + + void Dim(); + + void LoadModel(const std::string& model_dir); + + std::shared_ptr Clone() const override; + + void ForwardEncoderChunkImpl( + const std::vector& chunk_feats, + const int32& feat_dim, + std::vector* ctc_probs, + int32* vocab_dim) override; + + float ComputeAttentionScore(const float* prob, const std::vector& hyp, + int eos, int decode_out_len); + + void AttentionRescoring(const std::vector>& hyps, + float reverse_weight, + std::vector* rescoring_score) override; + + void EncoderOuts( + std::vector>* encoder_out) const; + + void GetInputOutputInfo(const std::shared_ptr& runtime, + std::vector* in_names, + std::vector* out_names); + private: + ModelOptions opts_; + + int encoder_output_size_ = 0; + int num_blocks_ = 0; + int cnn_module_kernel_ = 0; + int head_ = 0; + + // sessions + std::shared_ptr encoder_ = nullptr; + std::shared_ptr rescore_ = nullptr; + std::shared_ptr ctc_ = nullptr; + + + // node names + std::vector encoder_in_names_, encoder_out_names_; + std::vector ctc_in_names_, ctc_out_names_; + std::vector rescore_in_names_, rescore_out_names_; + + // caches + fastdeploy::FDTensor att_cache_ort_; + fastdeploy::FDTensor cnn_cache_ort_; + std::vector encoder_outs_; + + std::vector att_cache_; + std::vector cnn_cache_; +}; + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/CMakeLists.txt b/runtime/engine/asr/recognizer/CMakeLists.txt new file mode 100644 index 00000000..e8c86505 --- /dev/null +++ b/runtime/engine/asr/recognizer/CMakeLists.txt @@ -0,0 +1,26 @@ +set(srcs) + +list(APPEND srcs + recognizer_controller.cc + recognizer_controller_impl.cc + recognizer_instance.cc + recognizer.cc +) + +add_library(recognizer STATIC ${srcs}) +target_link_libraries(recognizer PUBLIC decoder) + +set(TEST_BINS + recognizer_batch_main + recognizer_batch_main2 + recognizer_main +) + +foreach(bin_name IN LISTS TEST_BINS) + add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) + target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi) + target_link_libraries(${bin_name} recognizer nnet decoder fst utils gflags glog kaldi-base kaldi-matrix kaldi-util) + target_compile_options(${bin_name} PRIVATE ${PADDLE_COMPILE_FLAGS}) + target_include_directories(${bin_name} PRIVATE ${pybind11_INCLUDE_DIRS} ${PROJECT_SOURCE_DIR}) + target_link_libraries(${bin_name} ${PYTHON_LIBRARIES} ${PADDLE_LINK_FLAGS} -ldl) +endforeach() diff --git a/runtime/engine/asr/recognizer/recognizer.cc b/runtime/engine/asr/recognizer/recognizer.cc new file mode 100644 index 00000000..3a95bcc8 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "recognizer/recognizer.h" +#include "recognizer/recognizer_instance.h" + +bool InitRecognizer(const std::string& model_file, + const std::string& word_symbol_table_file, + const std::string& fst_file, + int num_instance) { + return ppspeech::RecognizerInstance::GetInstance().Init(model_file, + word_symbol_table_file, + fst_file, + num_instance); +} + +int GetRecognizerInstanceId() { + return ppspeech::RecognizerInstance::GetInstance().GetRecognizerInstanceId(); +} + +void InitDecoder(int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().InitDecoder(instance_id); +} + +void AcceptData(const std::vector& waves, int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().Accept(waves, instance_id); +} + +void SetInputFinished(int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().SetInputFinished(instance_id); +} + +std::string GetFinalResult(int instance_id) { + return ppspeech::RecognizerInstance::GetInstance().GetResult(instance_id); +} \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer.h b/runtime/engine/asr/recognizer/recognizer.h new file mode 100644 index 00000000..bd7fb129 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer.h @@ -0,0 +1,28 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +bool InitRecognizer(const std::string& model_file, + const std::string& word_symbol_table_file, + const std::string& fst_file, + int num_instance); +int GetRecognizerInstanceId(); +void InitDecoder(int instance_id); +void AcceptData(const std::vector& waves, int instance_id); +void SetInputFinished(int instance_id); +std::string GetFinalResult(int instance_id); \ No newline at end of file diff --git a/runtime/engine/asr/recognizer/recognizer_batch_main.cc b/runtime/engine/asr/recognizer/recognizer_batch_main.cc new file mode 100644 index 00000000..0cc34f26 --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_batch_main.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/base/thread_pool.h" +#include "common/utils/file_utils.h" +#include "common/utils/strings.h" +#include "decoder/param.h" +#include "frontend/wave-reader.h" +#include "kaldi/util/table-types.h" +#include "nnet/u2_nnet.h" +#include "recognizer/recognizer_controller.h" + +DEFINE_string(wav_rspecifier, "", "test feature rspecifier"); +DEFINE_string(result_wspecifier, "", "test result wspecifier"); +DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size"); +DEFINE_int32(sample_rate, 16000, "sample rate"); +DEFINE_int32(njob, 3, "njob"); + +using std::string; +using std::vector; + +void SplitUtt(string wavlist_file, + vector>* uttlists, + vector>* wavlists, + int njob) { + vector wavlist; + wavlists->resize(njob); + uttlists->resize(njob); + ppspeech::ReadFileToVector(wavlist_file, &wavlist); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + string utt_str = wavlist[idx]; + vector utt_wav = ppspeech::StrSplit(utt_str, " \t"); + LOG(INFO) << utt_wav[0]; + CHECK_EQ(utt_wav.size(), size_t(2)); + uttlists->at(idx % njob).push_back(utt_wav[0]); + wavlists->at(idx % njob).push_back(utt_wav[1]); + } +} + +void recognizer_func(ppspeech::RecognizerController* recognizer_controller, + std::vector wavlist, + std::vector uttlist, + std::vector* results) { + int32 num_done = 0, num_err = 0; + double tot_wav_duration = 0.0; + double tot_attention_rescore_time = 0.0; + double tot_decode_time = 0.0; + int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate; + if (wavlist.empty()) return; + + results->reserve(wavlist.size()); + for (size_t idx = 0; idx < wavlist.size(); ++idx) { + std::string utt = uttlist[idx]; + std::string wav_file = wavlist[idx]; + std::ifstream infile; + infile.open(wav_file, std::ifstream::in); + kaldi::WaveData wave_data; + wave_data.Read(infile); + int32 recog_id = -1; + while (recog_id == -1) { + recog_id = recognizer_controller->GetRecognizerInstanceId(); + } + recognizer_controller->InitDecoder(recog_id); + LOG(INFO) << "utt: " << utt; + LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec."; + double dur = wave_data.Duration(); + tot_wav_duration += dur; + + int32 this_channel = 0; + kaldi::SubVector waveform(wave_data.Data(), + this_channel); + int tot_samples = waveform.Dim(); + LOG(INFO) << "wav len (sample): " << tot_samples; + + int sample_offset = 0; + kaldi::Timer local_timer; + + while (sample_offset < tot_samples) { + int cur_chunk_size = + std::min(chunk_sample_size, tot_samples - sample_offset); + + std::vector wav_chunk(cur_chunk_size); + for (int i = 0; i < cur_chunk_size; ++i) { + wav_chunk[i] = waveform(sample_offset + i); + } + + recognizer_controller->Accept(wav_chunk, recog_id); + // no overlap + sample_offset += cur_chunk_size; + } + recognizer_controller->SetInputFinished(recog_id); + CHECK(sample_offset == tot_samples); + std::string result = recognizer_controller->GetFinalResult(recog_id); + if (result.empty()) { + // the TokenWriter can not write empty string. 
+ ++num_err; + LOG(INFO) << " the result of " << utt << " is empty"; + result = " "; + } + + tot_decode_time += local_timer.Elapsed(); + LOG(INFO) << utt << " " << result; + LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur + << " cost: " << local_timer.Elapsed(); + + results->push_back(result); + ++num_done; + } + LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done); + LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec"; + LOG(INFO) << "total decode cost:" << tot_decode_time << " sec"; + LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration; +} + +int main(int argc, char* argv[]) { + gflags::SetUsageMessage("Usage:"); + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + FLAGS_logtostderr = 1; + + int sample_rate = FLAGS_sample_rate; + float streaming_chunk = FLAGS_streaming_chunk; + int chunk_sample_size = streaming_chunk * sample_rate; + kaldi::TokenWriter result_writer(FLAGS_result_wspecifier); + int njob = FLAGS_njob; + LOG(INFO) << "sr: " << sample_rate; + LOG(INFO) << "chunk size (s): " << streaming_chunk; + LOG(INFO) << "chunk size (sample): " << chunk_sample_size; + + ppspeech::RecognizerResource resource = + ppspeech::RecognizerResource::InitFromFlags(); + ppspeech::RecognizerController recognizer_controller(njob, resource); + ThreadPool threadpool(njob); + vector> wavlist; + vector> uttlist; + vector> resultlist(njob); + vector> futurelist; + SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob); + for (size_t i = 0; i < njob; ++i) { + std::future f = threadpool.enqueue(recognizer_func, + &recognizer_controller, + wavlist[i], + uttlist[i], + &resultlist[i]); + futurelist.push_back(std::move(f)); + } + + for (size_t i = 0; i < njob; ++i) { + futurelist[i].get(); + } + + for (size_t idx = 0; idx < njob; ++idx) { + for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) { + string utt = uttlist[idx][utt_idx]; + string result = resultlist[idx][utt_idx]; + result_writer.Write(utt, result); + } + } + return 0; +} diff --git a/runtime/engine/asr/recognizer/recognizer_batch_main2.cc b/runtime/engine/asr/recognizer/recognizer_batch_main2.cc new file mode 100644 index 00000000..fc99bf0b --- /dev/null +++ b/runtime/engine/asr/recognizer/recognizer_batch_main2.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
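`recognizer_batch_main2.cc` below exercises the flat C-style API from `recognizer/recognizer.h` instead of holding a `RecognizerController` directly. A minimal single-utterance sketch of that API (the `float` sample type and the blocking behaviour of `GetFinalResult()` are assumptions inferred from how the batch mains use it):

```cpp
#include <string>
#include <vector>

#include "recognizer/recognizer.h"  // include path as laid out in this patch

// Illustrative only; assumes InitRecognizer(model, words, fst, njob)
// has already returned true.
std::string RecognizeOneUtt(const std::vector<float>& waveform) {
    int id = -1;
    while (id == -1) id = GetRecognizerInstanceId();  // wait for a free slot
    InitDecoder(id);
    AcceptData(waveform, id);   // may also be fed chunk by chunk
    SetInputFinished(id);
    return GetFinalResult(id);  // assumption: blocks until decoding finishes
}
```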
diff --git a/runtime/engine/asr/recognizer/recognizer_batch_main2.cc b/runtime/engine/asr/recognizer/recognizer_batch_main2.cc
new file mode 100644
index 00000000..fc99bf0b
--- /dev/null
+++ b/runtime/engine/asr/recognizer/recognizer_batch_main2.cc
@@ -0,0 +1,168 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "common/base/thread_pool.h"
+#include "common/utils/file_utils.h"
+#include "common/utils/strings.h"
+#include "decoder/param.h"
+#include "frontend/wave-reader.h"
+#include "kaldi/util/table-types.h"
+#include "nnet/u2_nnet.h"
+#include "recognizer/recognizer.h"
+
+DEFINE_string(wav_rspecifier, "", "input wav list file");
+DEFINE_string(result_wspecifier, "", "test result wspecifier");
+DEFINE_double(streaming_chunk, 0.36, "streaming feature chunk size");
+DEFINE_int32(sample_rate, 16000, "sample rate");
+DEFINE_int32(njob, 3, "njob");
+
+using std::string;
+using std::vector;
+
+void SplitUtt(string wavlist_file,
+              vector<vector<string>>* uttlists,
+              vector<vector<string>>* wavlists,
+              int njob) {
+    vector<string> wavlist;
+    wavlists->resize(njob);
+    uttlists->resize(njob);
+    ppspeech::ReadFileToVector(wavlist_file, &wavlist);
+    for (size_t idx = 0; idx < wavlist.size(); ++idx) {
+        string utt_str = wavlist[idx];
+        vector<string> utt_wav = ppspeech::StrSplit(utt_str, " \t");
+        LOG(INFO) << utt_wav[0];
+        CHECK_EQ(utt_wav.size(), size_t(2));
+        uttlists->at(idx % njob).push_back(utt_wav[0]);
+        wavlists->at(idx % njob).push_back(utt_wav[1]);
+    }
+}
+
+void recognizer_func(std::vector<string> wavlist,
+                     std::vector<string> uttlist,
+                     std::vector<string>* results) {
+    int32 num_done = 0, num_err = 0;
+    double tot_wav_duration = 0.0;
+    double tot_attention_rescore_time = 0.0;
+    double tot_decode_time = 0.0;
+    int chunk_sample_size = FLAGS_streaming_chunk * FLAGS_sample_rate;
+    if (wavlist.empty()) return;
+
+    results->reserve(wavlist.size());
+    for (size_t idx = 0; idx < wavlist.size(); ++idx) {
+        std::string utt = uttlist[idx];
+        std::string wav_file = wavlist[idx];
+        std::ifstream infile;
+        infile.open(wav_file, std::ifstream::in);
+        kaldi::WaveData wave_data;
+        wave_data.Read(infile);
+        int32 recog_id = -1;
+        while (recog_id == -1) {
+            recog_id = GetRecognizerInstanceId();
+        }
+        InitDecoder(recog_id);
+        LOG(INFO) << "utt: " << utt;
+        LOG(INFO) << "wav dur: " << wave_data.Duration() << " sec.";
+        double dur = wave_data.Duration();
+        tot_wav_duration += dur;
+
+        int32 this_channel = 0;
+        kaldi::SubVector<kaldi::BaseFloat> waveform(wave_data.Data(),
+                                                    this_channel);
+        int tot_samples = waveform.Dim();
+        LOG(INFO) << "wav len (sample): " << tot_samples;
+
+        int sample_offset = 0;
+        kaldi::Timer local_timer;
+
+        while (sample_offset < tot_samples) {
+            int cur_chunk_size =
+                std::min(chunk_sample_size, tot_samples - sample_offset);
+
+            std::vector<float> wav_chunk(cur_chunk_size);
+            for (int i = 0; i < cur_chunk_size; ++i) {
+                wav_chunk[i] = waveform(sample_offset + i);
+            }
+
+            AcceptData(wav_chunk, recog_id);
+            // no overlap
+            sample_offset += cur_chunk_size;
+        }
+        SetInputFinished(recog_id);
+        CHECK(sample_offset == tot_samples);
+        std::string result = GetFinalResult(recog_id);
+        if (result.empty()) {
+            // the TokenWriter cannot write an empty string.
+            ++num_err;
+            LOG(INFO) << " the result of " << utt << " is empty";
+            result = " ";
+        }
+
+        tot_decode_time += local_timer.Elapsed();
+        LOG(INFO) << utt << " " << result;
+        LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
+                  << " cost: " << local_timer.Elapsed();
+
+        results->push_back(result);
+        ++num_done;
+    }
+    LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
+    LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
+    LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
+    LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
+}
+
+int main(int argc, char* argv[]) {
+    gflags::SetUsageMessage("Usage:");
+    gflags::ParseCommandLineFlags(&argc, &argv, false);
+    google::InitGoogleLogging(argv[0]);
+    google::InstallFailureSignalHandler();
+    FLAGS_logtostderr = 1;
+
+    int sample_rate = FLAGS_sample_rate;
+    float streaming_chunk = FLAGS_streaming_chunk;
+    int chunk_sample_size = streaming_chunk * sample_rate;
+    kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);
+    int njob = FLAGS_njob;
+    LOG(INFO) << "sr: " << sample_rate;
+    LOG(INFO) << "chunk size (s): " << streaming_chunk;
+    LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
+
+    InitRecognizer(FLAGS_model_path, FLAGS_word_symbol_table, FLAGS_graph_path, njob);
+    ThreadPool threadpool(njob);
+    vector<vector<string>> wavlist;
+    vector<vector<string>> uttlist;
+    vector<vector<string>> resultlist(njob);
+    vector<std::future<void>> futurelist;
+    SplitUtt(FLAGS_wav_rspecifier, &uttlist, &wavlist, njob);
+    for (size_t i = 0; i < njob; ++i) {
+        std::future<void> f = threadpool.enqueue(recognizer_func,
+                                                 wavlist[i],
+                                                 uttlist[i],
+                                                 &resultlist[i]);
+        futurelist.push_back(std::move(f));
+    }
+
+    for (size_t i = 0; i < njob; ++i) {
+        futurelist[i].get();
+    }
+
+    for (size_t idx = 0; idx < njob; ++idx) {
+        for (size_t utt_idx = 0; utt_idx < uttlist[idx].size(); ++utt_idx) {
+            string utt = uttlist[idx][utt_idx];
+            string result = resultlist[idx][utt_idx];
+            result_writer.Write(utt, result);
+        }
+    }
+    return 0;
+}
diff --git a/runtime/engine/asr/recognizer/recognizer_controller.cc b/runtime/engine/asr/recognizer/recognizer_controller.cc
new file mode 100644
index 00000000..ef549263
--- /dev/null
+++ b/runtime/engine/asr/recognizer/recognizer_controller.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "recognizer/recognizer_controller.h"
+#include "nnet/u2_nnet.h"
+
+namespace ppspeech {
+
+RecognizerController::RecognizerController(int num_worker, RecognizerResource resource) {
+    recognizer_workers.resize(num_worker);
+    for (size_t i = 0; i < num_worker; ++i) {
+        recognizer_workers[i].reset(new ppspeech::RecognizerControllerImpl(resource));
+        waiting_workers.push(i);
+    }
+}
+
+int RecognizerController::GetRecognizerInstanceId() {
+    // Check and pop under the same lock; an unlocked empty() check would
+    // race with the push in GetFinalResult().
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (waiting_workers.empty()) {
+        return -1;
+    }
+    int idx = waiting_workers.front();
+    waiting_workers.pop();
+    return idx;
+}
+
+RecognizerController::~RecognizerController() {
+    for (size_t i = 0; i < recognizer_workers.size(); ++i) {
+        recognizer_workers[i]->WaitFinished();
+    }
+}
+
+void RecognizerController::InitDecoder(int idx) {
+    recognizer_workers[idx]->InitDecoder();
+}
+
+std::string RecognizerController::GetFinalResult(int idx) {
+    recognizer_workers[idx]->WaitDecoderFinished();
+    recognizer_workers[idx]->AttentionRescoring();
+    std::string result = recognizer_workers[idx]->GetFinalResult();
+    {
+        std::unique_lock<std::mutex> lock(mutex_);
+        waiting_workers.push(idx);
+    }
+    return result;
+}
+
+void RecognizerController::Accept(std::vector<float> data, int idx) {
+    recognizer_workers[idx]->Accept(data);
+}
+
+void RecognizerController::SetInputFinished(int idx) {
+    recognizer_workers[idx]->SetInputFinished();
+}
+
+}  // namespace ppspeech
diff --git a/runtime/engine/asr/recognizer/recognizer_controller.h b/runtime/engine/asr/recognizer/recognizer_controller.h
new file mode 100644
index 00000000..16a8dd13
--- /dev/null
+++ b/runtime/engine/asr/recognizer/recognizer_controller.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <mutex>
+#include <queue>
+
+#include "recognizer/recognizer_controller_impl.h"
+
+namespace ppspeech {
+
+class RecognizerController {
+  public:
+    explicit RecognizerController(int num_worker, RecognizerResource resource);
+    ~RecognizerController();
+    int GetRecognizerInstanceId();
+    void InitDecoder(int idx);
+    void Accept(std::vector<float> data, int idx);
+    void SetInputFinished(int idx);
+    std::string GetFinalResult(int idx);
+
+  private:
+    std::queue<int> waiting_workers;
+    std::mutex mutex_;
+    std::vector<std::unique_ptr<ppspeech::RecognizerControllerImpl>> recognizer_workers;
+
+    DISALLOW_COPY_AND_ASSIGN(RecognizerController);
+};
+
+}  // namespace ppspeech
\ No newline at end of file
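The controller reduces to a small acquire/feed/collect cycle. A minimal single-utterance sketch (resource setup and chunking as in the batch main above; `chunks` is an assumed std::vector of audio chunks):

    ppspeech::RecognizerController controller(/*num_worker=*/2, resource);
    int id = -1;
    while (id == -1) {
        // -1 means no worker is free yet; the batch main spins exactly like this
        id = controller.GetRecognizerInstanceId();
    }
    controller.InitDecoder(id);
    for (const std::vector<float>& chunk : chunks) {
        controller.Accept(chunk, id);
    }
    controller.SetInputFinished(id);
    // Blocks until decoding and rescoring finish, and is also what returns
    // the worker to the waiting queue, so every acquired id must reach it.
    std::string text = controller.GetFinalResult(id);
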
diff --git a/speechx/speechx/recognizer/u2_recognizer.cc b/runtime/engine/asr/recognizer/recognizer_controller_impl.cc
similarity index 57%
rename from speechx/speechx/recognizer/u2_recognizer.cc
rename to runtime/engine/asr/recognizer/recognizer_controller_impl.cc
index d1d308eb..cc4d3c78 100644
--- a/speechx/speechx/recognizer/u2_recognizer.cc
+++ b/runtime/engine/asr/recognizer/recognizer_controller_impl.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
@@ -12,86 +12,180 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "recognizer/u2_recognizer.h"
-
-#include "nnet/u2_nnet.h"
+#include "recognizer/recognizer_controller_impl.h"
+#include "decoder/ctc_prefix_beam_search_decoder.h"
+#include "common/utils/strings.h"
 
 namespace ppspeech {
 
-using kaldi::BaseFloat;
-using kaldi::SubVector;
-using kaldi::Vector;
-using kaldi::VectorBase;
-using std::unique_ptr;
-using std::vector;
-
-U2Recognizer::U2Recognizer(const U2RecognizerResource& resource)
-    : opts_(resource) {
+RecognizerControllerImpl::RecognizerControllerImpl(const RecognizerResource& resource)
+    : opts_(resource) {
+    BaseFloat am_scale = resource.acoustic_scale;
+    BaseFloat blank_threshold = resource.blank_threshold;
     const FeaturePipelineOptions& feature_opts = resource.feature_pipeline_opts;
-    feature_pipeline_.reset(new FeaturePipeline(feature_opts));
+    std::shared_ptr<FeaturePipeline> feature_pipeline(
+        new FeaturePipeline(feature_opts));
+    std::shared_ptr<NnetBase> nnet;
+#ifndef USE_ONNX
+    nnet = resource.nnet->Clone();
+#else
+    if (resource.model_opts.with_onnx_model) {
+        nnet.reset(new U2OnnxNnet(resource.model_opts));
+    } else {
+        nnet = resource.nnet->Clone();
+    }
+#endif
+    nnet_producer_.reset(new NnetProducer(nnet, feature_pipeline, blank_threshold));
+    nnet_thread_ = std::thread(RunNnetEvaluation, this);
+
+    decodable_.reset(new Decodable(nnet_producer_, am_scale));
+    if (resource.decoder_opts.tlg_decoder_opts.fst_path.empty()) {
+        LOG(INFO) << "Init PrefixBeamSearch Decoder";
+        decoder_ = std::make_unique<CTCPrefixBeamSearch>(
+            resource.decoder_opts.ctc_prefix_search_opts);
+    } else {
+        LOG(INFO) << "Init TLGDecoder";
+        decoder_ = std::make_unique<TLGDecoder>(
+            resource.decoder_opts.tlg_decoder_opts);
+    }
 
-    std::shared_ptr<NnetBase> nnet(new U2Nnet(resource.model_opts));
+    symbol_table_ = decoder_->WordSymbolTable();
+    global_frame_offset_ = 0;
+    input_finished_ = false;
+    num_frames_ = 0;
+    result_.clear();
+}
 
-    BaseFloat am_scale = resource.acoustic_scale;
-    decodable_.reset(new Decodable(nnet, feature_pipeline_, am_scale));
+RecognizerControllerImpl::~RecognizerControllerImpl() {
+    WaitFinished();
+}
 
-    CHECK_NE(resource.vocab_path, "");
-    decoder_.reset(new CTCPrefixBeamSearch(
-        resource.vocab_path, resource.decoder_opts.ctc_prefix_search_opts));
+void RecognizerControllerImpl::Reset() {
+    nnet_producer_->Reset();
+}
 
-    unit_table_ = decoder_->VocabTable();
-    symbol_table_ = unit_table_;
+void RecognizerControllerImpl::RunDecoder(RecognizerControllerImpl* me) {
+    me->RunDecoderInternal();
+}
 
-    input_finished_ = false;
+void RecognizerControllerImpl::RunDecoderInternal() {
+    LOG(INFO) << "DecoderInternal begin";
+    while (!nnet_producer_->IsFinished()) {
+        nnet_condition_.notify_one();
+        decoder_->AdvanceDecode(decodable_);
+    }
+    decoder_->AdvanceDecode(decodable_);
+    UpdateResult(false);
+    LOG(INFO) << "DecoderInternal exit";
+}
 
-    Reset();
+void RecognizerControllerImpl::WaitDecoderFinished() {
+    if (decoder_thread_.joinable()) decoder_thread_.join();
 }
 
-void U2Recognizer::Reset() {
-    global_frame_offset_ = 0;
-    num_frames_ = 0;
-    result_.clear();
+void RecognizerControllerImpl::RunNnetEvaluation(RecognizerControllerImpl* me) {
+    me->RunNnetEvaluationInternal();
+}
 
-    decodable_->Reset();
-    decoder_->Reset();
+void RecognizerControllerImpl::SetInputFinished() {
+    nnet_producer_->SetInputFinished();
+    nnet_condition_.notify_one();
+    LOG(INFO) << "Set Input Finished";
 }
 
-void U2Recognizer::ResetContinuousDecoding() {
-    global_frame_offset_ = num_frames_;
+void RecognizerControllerImpl::WaitFinished() {
+    abort_ = true;
+    LOG(INFO) << "nnet wait finished";
+    nnet_condition_.notify_one();
+    if (nnet_thread_.joinable()) {
+        nnet_thread_.join();
+    }
+}
+
+void RecognizerControllerImpl::RunNnetEvaluationInternal() {
+    bool result = false;
+    LOG(INFO) << "NnetEvaluationInternal begin";
+    while (!abort_) {
+        std::unique_lock<std::mutex> lock(nnet_mutex_);
+        nnet_condition_.wait(lock);
+        do {
+            result = nnet_producer_->Compute();
+            decoder_condition_.notify_one();
+        } while (result);
+    }
+    LOG(INFO) << "NnetEvaluationInternal exit";
+}
+
+void RecognizerControllerImpl::Accept(std::vector<float> data) {
+    nnet_producer_->Accept(data);
+    nnet_condition_.notify_one();
+}
+
+void RecognizerControllerImpl::InitDecoder() {
+    global_frame_offset_ = 0;
+    input_finished_ = false;
     num_frames_ = 0;
     result_.clear();
 
     decodable_->Reset();
     decoder_->Reset();
+    decoder_thread_ = std::thread(RunDecoder, this);
 }
 
+void RecognizerControllerImpl::AttentionRescoring() {
+    decoder_->FinalizeSearch();
+    UpdateResult(false);
 
-void U2Recognizer::Accept(const VectorBase<kaldi::BaseFloat>& waves) {
-    kaldi::Timer timer;
-    feature_pipeline_->Accept(waves);
-    VLOG(1) << "feed waves cost: " << timer.Elapsed() << " sec. " << waves.Dim()
-            << " samples.";
-}
+    // No need to do rescoring
+    if (0.0 == opts_.decoder_opts.rescoring_weight) {
+        LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
+        return;
+    }
+    LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
+    // Inputs() returns N-best input ids, which is the basic unit for rescoring
+    // In CtcPrefixBeamSearch, inputs are the same to outputs
+    const auto& hypotheses = decoder_->Inputs();
+    int num_hyps = hypotheses.size();
+    if (num_hyps <= 0) {
+        return;
+    }
 
-void U2Recognizer::Decode() {
-    decoder_->AdvanceDecode(decodable_);
-    UpdateResult(false);
-}
+    std::vector<float> rescoring_score;
+    decodable_->AttentionRescoring(
+        hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
 
-void U2Recognizer::Rescoring() {
-    // Do attention Rescoring
-    AttentionRescoring();
+    // combine ctc score and rescoring score
+    for (size_t i = 0; i < num_hyps; i++) {
+        VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
+                << " ctc_score: " << result_[i].score
+                << " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
+                << " ctc_weight: " << opts_.decoder_opts.ctc_weight;
+        result_[i].score =
+            opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
+            opts_.decoder_opts.ctc_weight * result_[i].score;
+
+        VLOG(3) << "hyp: " << result_[0].sentence
+                << " score: " << result_[0].score;
+    }
+
+    std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
+    VLOG(3) << "result: " << result_[0].sentence
+            << " score: " << result_[0].score;
 }
 
-void U2Recognizer::UpdateResult(bool finish) {
+std::string RecognizerControllerImpl::GetFinalResult() { return result_[0].sentence; }
+
+std::string RecognizerControllerImpl::GetPartialResult() { return result_[0].sentence; }
+
+void RecognizerControllerImpl::UpdateResult(bool finish) {
     const auto& hypotheses = decoder_->Outputs();
     const auto& inputs = decoder_->Inputs();
     const auto& likelihood = decoder_->Likelihood();
     const auto& times = decoder_->Times();
     result_.clear();
 
-    CHECK_EQ(hypotheses.size(), likelihood.size());
+    CHECK_EQ(inputs.size(), likelihood.size());
     for (size_t i = 0; i < hypotheses.size(); i++) {
         const std::vector<int>& hypothesis = hypotheses[i];
 
@@ -99,21 +193,16 @@ void
U2Recognizer::UpdateResult(bool finish) {
         path.score = likelihood[i];
         for (size_t j = 0; j < hypothesis.size(); j++) {
             std::string word = symbol_table_->Find(hypothesis[j]);
-            // A detailed explanation of this if-else branch can be found in
-            // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
-            if (decoder_->Type() == kWfstBeamSearch) {
-                path.sentence += (" " + word);
-            } else {
-                path.sentence += (word);
-            }
+            path.sentence += (" " + word);
         }
+        path.sentence = DelBlank(path.sentence);
 
         // TimeStamp is only supported in final result
         // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
         // various FST operations when building the decoding graph. So here we
         // use time stamp of the input(e2e model unit), which is more accurate,
         // and it requires the symbol table of the e2e model used in training.
-        if (unit_table_ != nullptr && finish) {
+        if (symbol_table_ != nullptr && finish) {
            int offset = global_frame_offset_ * FrameShiftInMs();
 
            const std::vector<int>& input = inputs[i];
@@ -121,7 +210,7 @@
            CHECK_EQ(input.size(), time_stamp.size());
 
            for (size_t j = 0; j < input.size(); j++) {
-                std::string word = unit_table_->Find(input[j]);
+                std::string word = symbol_table_->Find(input[j]);
 
                int start = time_stamp[j] * FrameShiftInMs() - time_stamp_gap_ > 0
@@ -163,56 +252,4 @@
         }
     }
 }
 
-void U2Recognizer::AttentionRescoring() {
-    decoder_->FinalizeSearch();
-    UpdateResult(true);
-
-    // No need to do rescoring
-    if (0.0 == opts_.decoder_opts.rescoring_weight) {
-        LOG_EVERY_N(WARNING, 3) << "Not do AttentionRescoring!";
-        return;
-    }
-    LOG_EVERY_N(WARNING, 3) << "Do AttentionRescoring!";
-
-    // Inputs() returns N-best input ids, which is the basic unit for rescoring
-    // In CtcPrefixBeamSearch, inputs are the same to outputs
-    const auto& hypotheses = decoder_->Inputs();
-    int num_hyps = hypotheses.size();
-    if (num_hyps <= 0) {
-        return;
-    }
-
-    std::vector<float> rescoring_score;
-    decodable_->AttentionRescoring(
-        hypotheses, opts_.decoder_opts.reverse_weight, &rescoring_score);
-
-    // combine ctc score and rescoring score
-    for (size_t i = 0; i < num_hyps; i++) {
-        VLOG(3) << "hyp " << i << " rescoring_score: " << rescoring_score[i]
-                << " ctc_score: " << result_[i].score
-                << " rescoring_weight: " << opts_.decoder_opts.rescoring_weight
-                << " ctc_weight: " << opts_.decoder_opts.ctc_weight;
-        result_[i].score =
-            opts_.decoder_opts.rescoring_weight * rescoring_score[i] +
-            opts_.decoder_opts.ctc_weight * result_[i].score;
-
-        VLOG(3) << "hyp: " << result_[0].sentence
-                << " score: " << result_[0].score;
-    }
-
-    std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
-    VLOG(3) << "result: " << result_[0].sentence
-            << " score: " << result_[0].score;
-}
-
-std::string U2Recognizer::GetFinalResult() { return result_[0].sentence; }
-
-std::string U2Recognizer::GetPartialResult() { return result_[0].sentence; }
-
-void U2Recognizer::SetFinished() {
-    feature_pipeline_->SetFinished();
-    input_finished_ = true;
-}
-
-
 }  // namespace ppspeech
\ No newline at end of file
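The rescoring step above is a plain linear interpolation, score_i = rescoring_weight * rescoring_score_i + ctc_weight * ctc_score_i, followed by a re-sort of the n-best list. With illustrative values rescoring_weight = 1.0 and ctc_weight = 0.5, a hypothesis with rescoring_score = -3.2 and ctc_score = -5.0 ends up at 1.0 * (-3.2) + 0.5 * (-5.0) = -5.7 before sorting.
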
diff --git a/runtime/engine/asr/recognizer/recognizer_controller_impl.h b/runtime/engine/asr/recognizer/recognizer_controller_impl.h
new file mode 100644
index 00000000..3ff6faa6
--- /dev/null
+++ b/runtime/engine/asr/recognizer/recognizer_controller_impl.h
@@ -0,0 +1,89 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "decoder/common.h"
+#include "fst/fstlib.h"
+#include "fst/symbol-table.h"
+#include "nnet/u2_nnet.h"
+#include "nnet/nnet_producer.h"
+#ifdef USE_ONNX
+#include "nnet/u2_onnx_nnet.h"
+#endif
+#include "nnet/decodable.h"
+#include "recognizer/recognizer_resource.h"
+
+#include <thread>
+
+namespace ppspeech {
+
+class RecognizerControllerImpl {
+  public:
+    explicit RecognizerControllerImpl(const RecognizerResource& resource);
+    ~RecognizerControllerImpl();
+    void Accept(std::vector<float> data);
+    void InitDecoder();
+    void SetInputFinished();
+    std::string GetFinalResult();
+    std::string GetPartialResult();
+    void Rescoring();
+    void Reset();
+    void WaitDecoderFinished();
+    void WaitFinished();
+    void AttentionRescoring();
+    bool DecodedSomething() const {
+        return !result_.empty() && !result_[0].sentence.empty();
+    }
+    int FrameShiftInMs() const {
+        // TODO: derive from the nnet subsampling rate and the feature frame
+        // shift, as the removed U2Recognizer::FrameShiftInMs() did.
+        return 1;
+    }
+
+  private:
+    static void RunNnetEvaluation(RecognizerControllerImpl* me);
+    void RunNnetEvaluationInternal();
+    static void RunDecoder(RecognizerControllerImpl* me);
+    void RunDecoderInternal();
+    void UpdateResult(bool finish = false);
+
+    std::shared_ptr<Decodable> decodable_;
+    std::unique_ptr<DecoderBase> decoder_;
+    std::shared_ptr<NnetProducer> nnet_producer_;
+
+    // e2e unit symbol table
+    std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
+    std::vector<DecodeResult> result_;
+
+    RecognizerResource opts_;
+    bool abort_ = false;
+    // global decoded frame offset
+    int global_frame_offset_;
+    // cur decoded frame num
+    int num_frames_;
+    // timestamp gap between words in a sentence
+    const int time_stamp_gap_ = 100;
+    bool input_finished_;
+
+    std::mutex nnet_mutex_;
+    std::mutex decoder_mutex_;
+    std::condition_variable nnet_condition_;
+    std::condition_variable decoder_condition_;
+    std::thread nnet_thread_;
+    std::thread decoder_thread_;
+
+    DISALLOW_COPY_AND_ASSIGN(RecognizerControllerImpl);
+};
+
+}  // namespace ppspeech
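The header makes the two-stage pipeline explicit: Accept() wakes nnet_thread_ through nnet_condition_, each Compute() wakes decoder_thread_ through decoder_condition_, and InitDecoder()/WaitDecoderFinished() bound one utterance. A distilled, standalone sketch of that handoff follows; it is an illustration, not the committed code, and it adds a wait predicate (the committed RunNnetEvaluationInternal() waits unconditionally, so a notify that lands between Compute() returning false and the next wait can stall the loop until another Accept() arrives):

    #include <condition_variable>
    #include <mutex>
    #include <queue>
    #include <thread>
    #include <vector>

    std::mutex m;
    std::condition_variable cv;
    std::queue<std::vector<float>> pending;  // stand-in for the NnetProducer input side
    bool finished = false;

    // Plays the role of RunNnetEvaluationInternal(): sleep until woken, then drain.
    void NnetLoop() {
        while (true) {
            std::unique_lock<std::mutex> lock(m);
            cv.wait(lock, [] { return !pending.empty() || finished; });
            while (!pending.empty()) {
                pending.pop();  // "Compute()" one chunk; the real loop notifies the decoder here
            }
            if (finished) return;
        }
    }

    int main() {
        std::thread nnet(NnetLoop);
        {
            std::lock_guard<std::mutex> lock(m);
            pending.push(std::vector<float>(5760));  // one 0.36 s chunk at 16 kHz
        }
        cv.notify_one();  // Accept() in the controller is this push + notify
        {
            std::lock_guard<std::mutex> lock(m);
            finished = true;  // SetInputFinished() analogue
        }
        cv.notify_one();
        nnet.join();
        return 0;
    }
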
diff --git a/runtime/engine/asr/recognizer/recognizer_instance.cc b/runtime/engine/asr/recognizer/recognizer_instance.cc
new file mode 100644
index 00000000..b9019ec4
--- /dev/null
+++ b/runtime/engine/asr/recognizer/recognizer_instance.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "recognizer/recognizer_instance.h"
+
+
+namespace ppspeech {
+
+RecognizerInstance& RecognizerInstance::GetInstance() {
+    static RecognizerInstance instance;
+    return instance;
+}
+
+bool RecognizerInstance::Init(const std::string& model_file,
+                              const std::string& word_symbol_table_file,
+                              const std::string& fst_file,
+                              int num_instance) {
+    RecognizerResource resource = RecognizerResource::InitFromFlags();
+    resource.model_opts.model_path = model_file;
+    //resource.vocab_path = word_symbol_table_file;
+    if (!fst_file.empty()) {
+        resource.decoder_opts.tlg_decoder_opts.fst_path = fst_file;
+        resource.decoder_opts.tlg_decoder_opts.word_symbol_table =
+            word_symbol_table_file;
+    } else {
+        resource.decoder_opts.ctc_prefix_search_opts.word_symbol_table =
+            word_symbol_table_file;
+    }
+    recognizer_controller_ = std::make_unique<RecognizerController>(num_instance, resource);
+    return true;
+}
+
+void RecognizerInstance::InitDecoder(int idx) {
+    recognizer_controller_->InitDecoder(idx);
+    return;
+}
+
+int RecognizerInstance::GetRecognizerInstanceId() {
+    return recognizer_controller_->GetRecognizerInstanceId();
+}
+
+void RecognizerInstance::Accept(const std::vector<float>& waves, int idx) const {
+    recognizer_controller_->Accept(waves, idx);
+    return;
+}
+
+void RecognizerInstance::SetInputFinished(int idx) const {
+    recognizer_controller_->SetInputFinished(idx);
+    return;
+}
+
+std::string RecognizerInstance::GetResult(int idx) const {
+    return recognizer_controller_->GetFinalResult(idx);
+}
+
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/runtime/engine/asr/recognizer/recognizer_instance.h b/runtime/engine/asr/recognizer/recognizer_instance.h
new file mode 100644
index 00000000..ef8f524d
--- /dev/null
+++ b/runtime/engine/asr/recognizer/recognizer_instance.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "base/common.h"
+#include "recognizer/recognizer_controller.h"
+
+namespace ppspeech {
+
+class RecognizerInstance {
+  public:
+    static RecognizerInstance& GetInstance();
+    RecognizerInstance() {}
+    ~RecognizerInstance() {}
+    bool Init(const std::string& model_file,
+              const std::string& word_symbol_table_file,
+              const std::string& fst_file,
+              int num_instance);
+    int GetRecognizerInstanceId();
+    void InitDecoder(int idx);
+    void Accept(const std::vector<float>& waves, int idx) const;
+    void SetInputFinished(int idx) const;
+    std::string GetResult(int idx) const;
+
+  private:
+    std::unique_ptr<RecognizerController> recognizer_controller_;
+};
+
+
+}  // namespace ppspeech
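The singleton wraps the controller behind a C-friendly surface; the free functions used by recognizer_batch_main2.cc above presumably forward to it. A minimal sketch (file names and `samples` are illustrative):

    auto& inst = ppspeech::RecognizerInstance::GetInstance();
    // Empty fst_file selects CTC prefix beam search; a non-empty one selects TLG decoding.
    inst.Init("model.onnx", "unit.txt", /*fst_file=*/"", /*num_instance=*/4);
    int id = -1;
    while (id == -1) id = inst.GetRecognizerInstanceId();
    inst.InitDecoder(id);
    inst.Accept(samples, id);  // samples: std::vector<float> of 16 kHz audio
    inst.SetInputFinished(id);
    std::string text = inst.GetResult(id);
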
diff --git a/speechx/speechx/recognizer/u2_recognizer_main.cc b/runtime/engine/asr/recognizer/recognizer_main.cc
similarity index 75%
rename from speechx/speechx/recognizer/u2_recognizer_main.cc
rename to runtime/engine/asr/recognizer/recognizer_main.cc
index d7c58407..99b7b4dd 100644
--- a/speechx/speechx/recognizer/u2_recognizer_main.cc
+++ b/runtime/engine/asr/recognizer/recognizer_main.cc
@@ -13,9 +13,9 @@
 // limitations under the License.
 
 #include "decoder/param.h"
-#include "kaldi/feat/wave-reader.h"
+#include "frontend/wave-reader.h"
 #include "kaldi/util/table-types.h"
-#include "recognizer/u2_recognizer.h"
+#include "recognizer/recognizer_controller.h"
 
 DEFINE_string(wav_rspecifier, "", "test feature rspecifier");
 DEFINE_string(result_wspecifier, "", "test result wspecifier");
@@ -31,6 +31,7 @@ int main(int argc, char* argv[]) {
 
     int32 num_done = 0, num_err = 0;
     double tot_wav_duration = 0.0;
+    double tot_attention_rescore_time = 0.0;
     double tot_decode_time = 0.0;
 
     kaldi::SequentialTableReader<kaldi::WaveHolder> wav_reader(
@@ -44,11 +45,13 @@
     LOG(INFO) << "chunk size (s): " << streaming_chunk;
     LOG(INFO) << "chunk size (sample): " << chunk_sample_size;
 
-    ppspeech::U2RecognizerResource resource =
-        ppspeech::U2RecognizerResource::InitFromFlags();
-    ppspeech::U2Recognizer recognizer(resource);
+    ppspeech::RecognizerResource resource =
+        ppspeech::RecognizerResource::InitFromFlags();
+    std::shared_ptr<ppspeech::RecognizerControllerImpl> recognizer_ptr(
+        new ppspeech::RecognizerControllerImpl(resource));
 
     for (; !wav_reader.Done(); wav_reader.Next()) {
+        recognizer_ptr->InitDecoder();
         std::string utt = wav_reader.Key();
         const kaldi::WaveData& wave_data = wav_reader.Value();
         LOG(INFO) << "utt: " << utt;
@@ -63,45 +66,32 @@
         LOG(INFO) << "wav len (sample): " << tot_samples;
 
         int sample_offset = 0;
-        int cnt = 0;
-        kaldi::Timer timer;
         kaldi::Timer local_timer;
 
         while (sample_offset < tot_samples) {
             int cur_chunk_size =
                 std::min(chunk_sample_size, tot_samples - sample_offset);
 
-            kaldi::Vector<kaldi::BaseFloat> wav_chunk(cur_chunk_size);
+            std::vector<float> wav_chunk(cur_chunk_size);
             for (int i = 0; i < cur_chunk_size; ++i) {
-                wav_chunk(i) = waveform(sample_offset + i);
+                wav_chunk[i] = waveform(sample_offset + i);
             }
-            // wav_chunk = waveform.Range(sample_offset + i, cur_chunk_size);
 
-            recognizer.Accept(wav_chunk);
-            if (cur_chunk_size < chunk_sample_size) {
-                recognizer.SetFinished();
-            }
-            recognizer.Decode();
-            if (recognizer.DecodedSomething()) {
-                LOG(INFO) << "Pratial result: " << cnt << " "
-                          << recognizer.GetPartialResult();
-            }
+            recognizer_ptr->Accept(wav_chunk);
 
             // no overlap
             sample_offset += cur_chunk_size;
-            cnt++;
         }
         CHECK(sample_offset == tot_samples);
+        recognizer_ptr->SetInputFinished();
+        recognizer_ptr->WaitDecoderFinished();
 
-        // second pass decoding
-        recognizer.Rescoring();
-
-        tot_decode_time += timer.Elapsed();
-
-        std::string result = recognizer.GetFinalResult();
-
-        recognizer.Reset();
+        kaldi::Timer timer;
+        recognizer_ptr->AttentionRescoring();
+        float rescore_time = timer.Elapsed();
+        tot_attention_rescore_time += rescore_time;
 
+        std::string result = recognizer_ptr->GetFinalResult();
         if (result.empty()) {
             // the TokenWriter can not write empty string.
             ++num_err;
@@ -109,17 +99,20 @@
             continue;
         }
 
+        tot_decode_time += local_timer.Elapsed();
         LOG(INFO) << utt << " " << result;
         LOG(INFO) << " RTF: " << local_timer.Elapsed() / dur << " dur: " << dur
-                  << " cost: " << local_timer.Elapsed();
+                  << " cost: " << local_timer.Elapsed()
+                  << " rescore: " << rescore_time;
 
         result_writer.Write(utt, result);
         ++num_done;
     }
 
+    recognizer_ptr->WaitFinished();
     LOG(INFO) << "Done " << num_done << " out of " << (num_err + num_done);
     LOG(INFO) << "total wav duration is: " << tot_wav_duration << " sec";
     LOG(INFO) << "total decode cost:" << tot_decode_time << " sec";
+    LOG(INFO) << "total rescore cost:" << tot_attention_rescore_time << " sec";
     LOG(INFO) << "RTF is: " << tot_decode_time / tot_wav_duration;
 }
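The closing log lines define the reported RTF as total decode wall time over total audio duration: decoding 3600 s of audio in 720 s gives RTF = 720 / 3600 = 0.2, i.e. five times faster than real time, while anything above 1.0 means the recognizer falls behind a live stream. The per-utterance " RTF: " line is the same ratio computed from local_timer for a single wav.
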
diff --git a/speechx/speechx/recognizer/u2_recognizer.h b/runtime/engine/asr/recognizer/recognizer_resource.h
similarity index 54%
rename from speechx/speechx/recognizer/u2_recognizer.h
rename to runtime/engine/asr/recognizer/recognizer_resource.h
index 25850863..064a5b5b 100644
--- a/speechx/speechx/recognizer/u2_recognizer.h
+++ b/runtime/engine/asr/recognizer/recognizer_resource.h
@@ -1,27 +1,8 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
 #pragma once
 
-#include "decoder/common.h"
 #include "decoder/ctc_beam_search_opt.h"
-#include "decoder/ctc_prefix_beam_search_decoder.h"
-#include "decoder/decoder_itf.h"
-#include "frontend/audio/feature_pipeline.h"
-#include "fst/fstlib.h"
-#include "fst/symbol-table.h"
-#include "nnet/decodable.h"
+#include "decoder/ctc_tlg_decoder.h"
+#include "frontend/feature_pipeline.h"
 
 DECLARE_int32(nnet_decoder_chunk);
 DECLARE_int32(num_left_chunks);
@@ -30,9 +11,9 @@ DECLARE_double(rescoring_weight);
 DECLARE_double(reverse_weight);
 DECLARE_int32(nbest);
 DECLARE_int32(blank);
-
 DECLARE_double(acoustic_scale);
-DECLARE_string(vocab_path);
+DECLARE_double(blank_threshold);
+DECLARE_string(word_symbol_table);
 
 namespace ppspeech {
 
@@ -59,6 +40,7 @@ struct DecodeOptions {
     // CtcEndpointConfig ctc_endpoint_opts;
 
     CTCBeamSearchOptions ctc_prefix_search_opts{};
+    TLGDecoderOptions tlg_decoder_opts{};
 
     static DecodeOptions InitFromFlags() {
         DecodeOptions decoder_opts;
@@ -70,6 +52,11 @@
         decoder_opts.ctc_prefix_search_opts.blank = FLAGS_blank;
         decoder_opts.ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
         decoder_opts.ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
+        decoder_opts.ctc_prefix_search_opts.word_symbol_table =
+            FLAGS_word_symbol_table;
+        decoder_opts.tlg_decoder_opts =
+            ppspeech::TLGDecoderOptions::InitFromFlags();
+
         LOG(INFO) << "chunk_size: " << decoder_opts.chunk_size;
         LOG(INFO) << "num_left_chunks: " << decoder_opts.num_left_chunks;
         LOG(INFO) << "ctc_weight: " << decoder_opts.ctc_weight;
@@ -82,19 +69,20 @@
     }
 };
 
-struct U2RecognizerResource {
+struct RecognizerResource {
+    // decodable opt
     kaldi::BaseFloat acoustic_scale{1.0};
-    std::string vocab_path{};
+    kaldi::BaseFloat blank_threshold{0.98};
 
     FeaturePipelineOptions feature_pipeline_opts{};
     ModelOptions model_opts{};
     DecodeOptions decoder_opts{};
+    std::shared_ptr<NnetBase> nnet;
 
-    static U2RecognizerResource InitFromFlags() {
-        U2RecognizerResource resource;
-        resource.vocab_path = FLAGS_vocab_path;
+    static RecognizerResource InitFromFlags() {
+        RecognizerResource resource;
         resource.acoustic_scale = FLAGS_acoustic_scale;
-        LOG(INFO) << "vocab path: " << resource.vocab_path;
+        resource.blank_threshold = FLAGS_blank_threshold;
         LOG(INFO) << "acoustic_scale: " << resource.acoustic_scale;
 
         resource.feature_pipeline_opts =
@@ -104,69 +92,17 @@
                   << resource.feature_pipeline_opts.assembler_opts.fill_zero;
         resource.model_opts = ppspeech::ModelOptions::InitFromFlags();
         resource.decoder_opts = ppspeech::DecodeOptions::InitFromFlags();
+    #ifndef USE_ONNX
+        resource.nnet.reset(new U2Nnet(resource.model_opts));
+    #else
+        if (resource.model_opts.with_onnx_model) {
+            resource.nnet.reset(new U2OnnxNnet(resource.model_opts));
+        } else {
+            resource.nnet.reset(new U2Nnet(resource.model_opts));
+        }
+    #endif
         return resource;
     }
 };
-
-class U2Recognizer {
-  public:
-    explicit U2Recognizer(const U2RecognizerResource& resouce);
-    void Reset();
-    void ResetContinuousDecoding();
-
-    void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& waves);
-    void Decode();
-    void Rescoring();
-
-
-    std::string GetFinalResult();
-    std::string GetPartialResult();
-
-    void SetFinished();
-    bool IsFinished() { return input_finished_; }
-
-    bool DecodedSomething() const {
-        return !result_.empty() && !result_[0].sentence.empty();
-    }
-
-
-    int FrameShiftInMs() const {
-        // one decoder frame length in ms
-        return decodable_->Nnet()->SubsamplingRate() *
-               feature_pipeline_->FrameShift();
-    }
-
-
-    const std::vector<DecodeResult>& Result() const { return result_; }
-
-  private:
-    void AttentionRescoring();
-    void UpdateResult(bool finish = false);
-
-  private:
-    U2RecognizerResource opts_;
-
-    // std::shared_ptr<U2RecognizerResource> resource_;
-    // U2RecognizerResource resource_;
-    std::shared_ptr<FeaturePipeline> feature_pipeline_;
-    std::shared_ptr<Decodable> decodable_;
-    std::unique_ptr<CTCPrefixBeamSearch> decoder_;
-
-    // e2e unit symbol table
-    std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
-    std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
-
-    std::vector<DecodeResult> result_;
-
-    // global decoded frame offset
-    int global_frame_offset_;
-    // cur decoded frame num
-    int num_frames_;
-    // timestamp gap between words in a sentence
-    const int time_stamp_gap_ = 100;
-
-    bool input_finished_;
-};
-
-}  // namespace ppspeech
\ No newline at end of file
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/runtime/engine/asr/server/CMakeLists.txt b/runtime/engine/asr/server/CMakeLists.txt
new file mode 100644
index 00000000..566b42ee
--- /dev/null
+++ b/runtime/engine/asr/server/CMakeLists.txt
@@ -0,0 +1 @@
+#add_subdirectory(websocket)
diff --git a/speechx/speechx/protocol/websocket/CMakeLists.txt b/runtime/engine/asr/server/websocket/CMakeLists.txt
similarity index 98%
rename from speechx/speechx/protocol/websocket/CMakeLists.txt
rename to runtime/engine/asr/server/websocket/CMakeLists.txt
index cafbbec7..9991e47b 100644
--- a/speechx/speechx/protocol/websocket/CMakeLists.txt
+++ b/runtime/engine/asr/server/websocket/CMakeLists.txt
@@ -10,4 +10,4 @@ target_link_libraries(websocket_server_main PUBLIC fst websocket ${DEPS})
 
 add_executable(websocket_client_main ${CMAKE_CURRENT_SOURCE_DIR}/websocket_client_main.cc)
 target_include_directories(websocket_client_main PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
-target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS})
+target_link_libraries(websocket_client_main PUBLIC fst websocket ${DEPS})
\ No newline at end of file
diff --git a/speechx/speechx/protocol/websocket/websocket_client.cc b/runtime/engine/asr/server/websocket/websocket_client.cc
similarity index 100%
rename from speechx/speechx/protocol/websocket/websocket_client.cc
rename to runtime/engine/asr/server/websocket/websocket_client.cc
diff --git a/speechx/speechx/protocol/websocket/websocket_client.h b/runtime/engine/asr/server/websocket/websocket_client.h
similarity index 100%
rename from speechx/speechx/protocol/websocket/websocket_client.h
rename to runtime/engine/asr/server/websocket/websocket_client.h
diff --git a/speechx/speechx/protocol/websocket/websocket_client_main.cc b/runtime/engine/asr/server/websocket/websocket_client_main.cc
similarity index 100%
rename from speechx/speechx/protocol/websocket/websocket_client_main.cc
rename to runtime/engine/asr/server/websocket/websocket_client_main.cc
diff --git a/speechx/speechx/protocol/websocket/websocket_server.cc b/runtime/engine/asr/server/websocket/websocket_server.cc
similarity index 96%
rename from speechx/speechx/protocol/websocket/websocket_server.cc
rename to runtime/engine/asr/server/websocket/websocket_server.cc
index 14f2f6e9..d1bed1ca 100644
--- a/speechx/speechx/protocol/websocket/websocket_server.cc
+++ b/runtime/engine/asr/server/websocket/websocket_server.cc
@@ -32,14 +32,14 @@ void ConnectionHandler::OnSpeechStart() {
     decode_thread_ = std::make_shared(
        &ConnectionHandler::DecodeThreadFunc, this);
     got_start_tag_ = true;
-    LOG(INFO) << "Server: Recieved speech start signal, start reading speech";
+    LOG(INFO) << "Server: Received speech start signal, start reading speech";
     json::value rv = {{"status", "ok"}, {"type", "server_ready"}};
     ws_.text(true);
     ws_.write(asio::buffer(json::serialize(rv)));
 }
 
 void ConnectionHandler::OnSpeechEnd() {
-    LOG(INFO) << "Server: Recieved speech end signal";
+    LOG(INFO) << "Server: Received speech end signal";
     if (recognizer_ != nullptr) {
         recognizer_->SetFinished();
     }
@@ -70,8 +70,8 @@ void ConnectionHandler::OnSpeechData(const beast::flat_buffer& buffer) {
         pcm_data(i) = static_cast<float>(*pdata);
         pdata++;
     }
-    VLOG(2) << "Server: Recieved " << num_samples << " samples";
-    LOG(INFO) << "Server: Recieved " << num_samples << " samples";
+    VLOG(2) << "Server: Received " << num_samples << " samples";
+    LOG(INFO) << "Server: Received " << num_samples << " samples";
 
     CHECK(recognizer_ != nullptr);
     recognizer_->Accept(pcm_data);
diff --git a/speechx/speechx/protocol/websocket/websocket_server.h b/runtime/engine/asr/server/websocket/websocket_server.h
similarity index 100%
rename from speechx/speechx/protocol/websocket/websocket_server.h
rename to runtime/engine/asr/server/websocket/websocket_server.h
diff --git a/speechx/speechx/protocol/websocket/websocket_server_main.cc b/runtime/engine/asr/server/websocket/websocket_server_main.cc
similarity index 100%
rename from speechx/speechx/protocol/websocket/websocket_server_main.cc
rename to runtime/engine/asr/server/websocket/websocket_server_main.cc
diff --git a/runtime/engine/audio_classification/CMakeLists.txt b/runtime/engine/audio_classification/CMakeLists.txt
new file mode 100644
index 00000000..52f1efef
--- /dev/null
+++ b/runtime/engine/audio_classification/CMakeLists.txt
@@ -0,0 +1,3 @@
+# add_definitions("-DUSE_PADDLE_INFERENCE_BACKEND")
+add_definitions("-DUSE_ORT_BACKEND")
+add_subdirectory(nnet)
\ No newline at end of file
diff --git a/runtime/engine/audio_classification/nnet/CMakeLists.txt b/runtime/engine/audio_classification/nnet/CMakeLists.txt
new file mode 100644
index 00000000..bb7f8eec
--- /dev/null
+++ b/runtime/engine/audio_classification/nnet/CMakeLists.txt
@@ -0,0 +1,11 @@
+set(srcs
+  panns_nnet.cc
+  panns_interface.cc
+)
+
+add_library(cls SHARED ${srcs})
+target_link_libraries(cls PRIVATE ${FASTDEPLOY_LIBS} kaldi-matrix kaldi-base frontend utils)
+
+set(bin_name panns_nnet_main)
+add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
+target_link_libraries(${bin_name} gflags glog cls)
diff --git a/runtime/engine/audio_classification/nnet/panns_interface.cc b/runtime/engine/audio_classification/nnet/panns_interface.cc
new file mode 100644
index 00000000..d8b6a8b6
--- /dev/null
+++ b/runtime/engine/audio_classification/nnet/panns_interface.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "audio_classification/nnet/panns_interface.h"
+
+#include "audio_classification/nnet/panns_nnet.h"
+#include "common/base/config.h"
+
+namespace ppspeech {
+
+void* ClsCreateInstance(const char* conf_path) {
+    Config conf(conf_path);
+    // cls init
+    ppspeech::ClsNnetConf cls_nnet_conf;
+    cls_nnet_conf.wav_normal_ = conf.Read("wav_normal", true);
+    cls_nnet_conf.wav_normal_type_ =
+        conf.Read("wav_normal_type", std::string("linear"));
+    cls_nnet_conf.wav_norm_mul_factor_ = conf.Read("wav_norm_mul_factor", 1.0);
+    cls_nnet_conf.model_file_path_ = conf.Read("model_path", std::string(""));
+    cls_nnet_conf.param_file_path_ = conf.Read("param_path", std::string(""));
+    cls_nnet_conf.dict_file_path_ = conf.Read("dict_path", std::string(""));
+    cls_nnet_conf.num_cpu_thread_ = conf.Read("num_cpu_thread", 12);
+    cls_nnet_conf.samp_freq = conf.Read("samp_freq", 32000);
+    cls_nnet_conf.frame_length_ms = conf.Read("frame_length_ms", 32);
+    cls_nnet_conf.frame_shift_ms = conf.Read("frame_shift_ms", 10);
+    cls_nnet_conf.num_bins = conf.Read("num_bins", 64);
+    cls_nnet_conf.low_freq = conf.Read("low_freq", 50);
+    cls_nnet_conf.high_freq = conf.Read("high_freq", 14000);
+    cls_nnet_conf.dither = conf.Read("dither", 0.0);
+
+    ppspeech::ClsNnet* cls_model = new ppspeech::ClsNnet();
+    int ret = cls_model->Init(cls_nnet_conf);
+    if (ret != 0) {
+        // Init() reports failure with a non-zero code; do not hand out a
+        // half-initialized instance.
+        delete cls_model;
+        return NULL;
+    }
+    return static_cast<void*>(cls_model);
+}
+
+int ClsDestroyInstance(void* instance) {
+    ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
+    if (cls_model != NULL) {
+        delete cls_model;
+        cls_model = NULL;
+    }
+    return 0;
+}
+
+int ClsFeedForward(void* instance,
+                   const char* wav_path,
+                   int topk,
+                   char* result,
+                   int result_max_len) {
+    ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
+    if (cls_model == NULL) {
+        printf("instance is null\n");
+        return -1;
+    }
+    int ret = cls_model->Forward(wav_path, topk, result, result_max_len);
+    return ret;
+}
+
+int ClsReset(void* instance) {
+    ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
+    if (cls_model == NULL) {
+        printf("instance is null\n");
+        return -1;
+    }
+    cls_model->Reset();
+    return 0;
+}
+}  // namespace ppspeech
\ No newline at end of file
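ClsCreateInstance() is configured through a flat key=value file parsed by the Config class added later in this patch (default delimiter "=", comments after "#"). A plausible conf covering every key read above (paths and values are illustrative; unset keys fall back to the defaults in the code):

    wav_normal = true
    wav_normal_type = linear
    wav_norm_mul_factor = 1.0
    model_path = ./panns.onnx
    param_path =               # unused with the ONNX backend
    dict_path = ./label_dict.txt
    num_cpu_thread = 12
    samp_freq = 32000
    frame_length_ms = 32
    frame_shift_ms = 10
    num_bins = 64
    low_freq = 50
    high_freq = 14000
    dither = 0.0
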
diff --git a/runtime/engine/audio_classification/nnet/panns_interface.h b/runtime/engine/audio_classification/nnet/panns_interface.h
new file mode 100644
index 00000000..0d1ce95f
--- /dev/null
+++ b/runtime/engine/audio_classification/nnet/panns_interface.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace ppspeech {
+
+void* ClsCreateInstance(const char* conf_path);
+int ClsDestroyInstance(void* instance);
+int ClsFeedForward(void* instance,
+                   const char* wav_path,
+                   int topk,
+                   char* result,
+                   int result_max_len);
+int ClsReset(void* instance);
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/runtime/engine/audio_classification/nnet/panns_nnet.cc b/runtime/engine/audio_classification/nnet/panns_nnet.cc
new file mode 100644
index 00000000..37ba74f9
--- /dev/null
+++ b/runtime/engine/audio_classification/nnet/panns_nnet.cc
@@ -0,0 +1,227 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "audio_classification/nnet/panns_nnet.h"
+#ifdef WITH_PROFILING
+#include "kaldi/base/timer.h"
+#endif
+
+namespace ppspeech {
+
+ClsNnet::ClsNnet() {
+    // wav_reader_ = NULL;
+    runtime_ = NULL;
+}
+
+void ClsNnet::Reset() {
+    // wav_reader_->Clear();
+    ss_.str("");
+}
+
+int ClsNnet::Init(const ClsNnetConf& conf) {
+    conf_ = conf;
+    // init fbank opts
+    fbank_opts_.frame_opts.samp_freq = conf.samp_freq;
+    fbank_opts_.frame_opts.frame_length_ms = conf.frame_length_ms;
+    fbank_opts_.frame_opts.frame_shift_ms = conf.frame_shift_ms;
+    fbank_opts_.mel_opts.num_bins = conf.num_bins;
+    fbank_opts_.mel_opts.low_freq = conf.low_freq;
+    fbank_opts_.mel_opts.high_freq = conf.high_freq;
+    fbank_opts_.frame_opts.dither = conf.dither;
+    fbank_opts_.use_log_fbank = false;
+
+    // init dict
+    if (conf.dict_file_path_ != "") {
+        ReadFileToVector(conf.dict_file_path_, &dict_);
+    }
+
+    // init model
+    fastdeploy::RuntimeOption runtime_option;
+
+#ifdef USE_PADDLE_INFERENCE_BACKEND
+    runtime_option.SetModelPath(conf.model_file_path_,
+                                conf.param_file_path_,
+                                fastdeploy::ModelFormat::PADDLE);
+    runtime_option.UsePaddleInferBackend();
+#elif defined(USE_ORT_BACKEND)
+    runtime_option.SetModelPath(
+        conf.model_file_path_, "", fastdeploy::ModelFormat::ONNX);  // onnx
+    runtime_option.UseOrtBackend();                                 // onnx
+#elif defined(USE_PADDLE_LITE_BACKEND)
+    runtime_option.SetModelPath(conf.model_file_path_,
+                                conf.param_file_path_,
+                                fastdeploy::ModelFormat::PADDLE);
+    runtime_option.UseLiteBackend();
+#endif
+
+    runtime_option.SetCpuThreadNum(conf.num_cpu_thread_);
+    // runtime_option.DeletePaddleBackendPass("simplify_with_basic_ops_pass");
+    runtime_ = std::unique_ptr<fastdeploy::Runtime>(new fastdeploy::Runtime());
+    if (!runtime_->Init(runtime_option)) {
+        std::cerr << "--- Init FastDeploy Runtime Failed! "
+                  << "\n--- Model: " << conf.model_file_path_ << std::endl;
+        return -1;
+    } else {
+        std::cout << "--- Init FastDeploy Runtime Done! "
+                  << "\n--- Model: " << conf.model_file_path_ << std::endl;
+    }
+
+    Reset();
+    return 0;
+}
+
+int ClsNnet::Forward(const char* wav_path,
+                     int topk,
+                     char* result,
+                     int result_max_len) {
+#ifdef WITH_PROFILING
+    kaldi::Timer timer;
+    timer.Reset();
+#endif
+    // read wav
+    std::ifstream infile(wav_path, std::ifstream::in);
+    kaldi::WaveData wave_data;
+    wave_data.Read(infile);
+    int32 this_channel = 0;
+    kaldi::Matrix<kaldi::BaseFloat> wavform_kaldi = wave_data.Data();
+    // only get channel 0
+    int wavform_len = wavform_kaldi.NumCols();
+    std::vector<float> wavform(wavform_kaldi.Data(),
+                               wavform_kaldi.Data() + wavform_len);
+    WaveformFloatNormal(&wavform);
+    WaveformNormal(&wavform,
+                   conf_.wav_normal_,
+                   conf_.wav_normal_type_,
+                   conf_.wav_norm_mul_factor_);
+#ifdef PPS_DEBUG
+    {
+        std::ofstream fp("cls.wavform", std::ios::out);
+        for (int i = 0; i < wavform.size(); ++i) {
+            fp << std::setprecision(18) << wavform[i] << " ";
+        }
+        fp << "\n";
+    }
+#endif
+#ifdef WITH_PROFILING
+    printf("wav read consume: %fs\n", timer.Elapsed());
+#endif
+
+#ifdef WITH_PROFILING
+    timer.Reset();
+#endif
+
+    std::vector<float> feats;
+    std::unique_ptr<ppspeech::FrontendInterface> data_source(
+        new ppspeech::DataCache());
+    ppspeech::Fbank fbank(fbank_opts_, std::move(data_source));
+    fbank.Accept(wavform);
+    fbank.SetFinished();
+    fbank.Read(&feats);
+
+    int feat_dim = fbank_opts_.mel_opts.num_bins;
+    int num_frames = feats.size() / feat_dim;
+
+    for (int i = 0; i < num_frames; ++i) {
+        for (int j = 0; j < feat_dim; ++j) {
+            feats[i * feat_dim + j] = PowerTodb(feats[i * feat_dim + j]);
+        }
+    }
+#ifdef PPS_DEBUG
+    {
+        std::ofstream fp("cls.feat", std::ios::out);
+        for (int i = 0; i < num_frames; ++i) {
+            for (int j = 0; j < feat_dim; ++j) {
+                fp << std::setprecision(18) << feats[i * feat_dim + j] << " ";
+            }
+            fp << "\n";
+        }
+    }
+#endif
+#ifdef WITH_PROFILING
+    printf("extract fbank consume: %fs\n", timer.Elapsed());
+#endif
+
+    // infer
+    std::vector<float> model_out;
+#ifdef WITH_PROFILING
+    timer.Reset();
+#endif
+    ModelForward(feats.data(), num_frames, feat_dim, &model_out);
+#ifdef WITH_PROFILING
+    printf("fast deploy infer consume: %fs\n", timer.Elapsed());
+#endif
+#ifdef PPS_DEBUG
+    {
+        std::ofstream fp("cls.logits", std::ios::out);
+        for (int i = 0; i < model_out.size(); ++i) {
+            fp << std::setprecision(18) << model_out[i] << "\n";
+        }
+    }
+#endif
+
+    // construct result str
+    ss_ << "{";
+    GetTopkResult(topk, model_out);
+    ss_ << "}";
+
+    if (result_max_len <= ss_.str().size()) {
+        printf("result_max_len is shorter than result len\n");
+    }
+    snprintf(result, result_max_len, "%s", ss_.str().c_str());
+    return 0;
+}
+
+int ClsNnet::ModelForward(float* features,
+                          const int num_frames,
+                          const int feat_dim,
+                          std::vector<float>* model_out) {
+    // init input tensor shape
+    fastdeploy::TensorInfo info = runtime_->GetInputInfo(0);
+    info.shape = {1, num_frames, feat_dim};
+
+    std::vector<fastdeploy::FDTensor> input_tensors(1);
+    std::vector<fastdeploy::FDTensor> output_tensors(1);
+
+    input_tensors[0].SetExternalData({1, num_frames, feat_dim},
+                                     fastdeploy::FDDataType::FP32,
+                                     static_cast<void*>(features));
+
+    // get input name
+    input_tensors[0].name = info.name;
+
+    runtime_->Infer(input_tensors, &output_tensors);
+
+    // output_tensors[0].PrintInfo();
+    std::vector<int64_t> output_shape = output_tensors[0].Shape();
+    model_out->resize(output_shape[0] * output_shape[1]);
+    memcpy(static_cast<void*>(model_out->data()),
+           output_tensors[0].Data(),
+           output_shape[0] * output_shape[1] * sizeof(float));
+    return 0;
+}
+
+int ClsNnet::GetTopkResult(int k, const std::vector<float>& model_out) {
+    std::vector<float> values;
+    std::vector<int> indices;
+    TopK(model_out, k, &values, &indices);
+    for (int i = 0; i < k; ++i) {
+        if (i != 0) {
+            ss_ << ",";
+        }
+        ss_ << "\"" << dict_[indices[i]] << "\":\"" << values[i] << "\"";
+    }
+    return 0;
+}
+
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/runtime/engine/audio_classification/nnet/panns_nnet.h b/runtime/engine/audio_classification/nnet/panns_nnet.h
new file mode 100644
index 00000000..3a4a5718
--- /dev/null
+++ b/runtime/engine/audio_classification/nnet/panns_nnet.h
@@ -0,0 +1,74 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "common/frontend/data_cache.h"
+#include "common/frontend/fbank.h"
+#include "common/frontend/feature-fbank.h"
+#include "common/frontend/frontend_itf.h"
+#include "common/frontend/wave-reader.h"
+#include "common/utils/audio_process.h"
+#include "common/utils/file_utils.h"
+#include "fastdeploy/runtime.h"
+#include "kaldi/util/kaldi-io.h"
+#include "kaldi/util/table-types.h"
+
+namespace ppspeech {
+struct ClsNnetConf {
+    // wav
+    bool wav_normal_;
+    std::string wav_normal_type_;
+    float wav_norm_mul_factor_;
+    // model
+    std::string model_file_path_;
+    std::string param_file_path_;
+    std::string dict_file_path_;
+    int num_cpu_thread_;
+    // fbank
+    float samp_freq;
+    float frame_length_ms;
+    float frame_shift_ms;
+    int num_bins;
+    float low_freq;
+    float high_freq;
+    float dither;
+};
+
+class ClsNnet {
+  public:
+    ClsNnet();
+    int Init(const ClsNnetConf& conf);
+    int Forward(const char* wav_path,
+                int topk,
+                char* result,
+                int result_max_len);
+    void Reset();
+
+  private:
+    int ModelForward(float* features,
+                     const int num_frames,
+                     const int feat_dim,
+                     std::vector<float>* model_out);
+    int ModelForwardStream(std::vector<float>* feats);
+    int GetTopkResult(int k, const std::vector<float>& model_out);
+
+    ClsNnetConf conf_;
+    knf::FbankOptions fbank_opts_;
+    std::unique_ptr<fastdeploy::Runtime> runtime_;
+    std::vector<std::string> dict_;
+    std::stringstream ss_;
+};
+
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/runtime/engine/audio_classification/nnet/panns_nnet_main.cc b/runtime/engine/audio_classification/nnet/panns_nnet_main.cc
new file mode 100644
index 00000000..b47753f0
--- /dev/null
+++ b/runtime/engine/audio_classification/nnet/panns_nnet_main.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fstream>
+#include <string>
+
+#include "gflags/gflags.h"
+#include "glog/logging.h"
+#include "audio_classification/nnet/panns_interface.h"
+
+DEFINE_string(conf_path, "", "config path");
+DEFINE_string(scp_path, "", "wav scp path");
+DEFINE_string(topk, "", "print topk results");
+
+int main(int argc, char* argv[]) {
+    gflags::SetUsageMessage("Usage:");
+    gflags::ParseCommandLineFlags(&argc, &argv, false);
+    google::InitGoogleLogging(argv[0]);
+    google::InstallFailureSignalHandler();
+    FLAGS_logtostderr = 1;
+    CHECK_GT(FLAGS_conf_path.size(), 0);
+    CHECK_GT(FLAGS_scp_path.size(), 0);
+    CHECK_GT(FLAGS_topk.size(), 0);
+    void* instance = ppspeech::ClsCreateInstance(FLAGS_conf_path.c_str());
+    int ret = 0;
+    // read wav
+    std::ifstream ifs(FLAGS_scp_path);
+    std::string line = "";
+    int topk = std::atoi(FLAGS_topk.c_str());
+    while (getline(ifs, line)) {
+        // read wav
+        char result[1024] = {0};
+        ret = ppspeech::ClsFeedForward(
+            instance, line.c_str(), topk, result, 1024);
+        printf("%s %s\n", line.c_str(), result);
+        ret = ppspeech::ClsReset(instance);
+    }
+    ret = ppspeech::ClsDestroyInstance(instance);
+    return 0;
+}
diff --git a/runtime/engine/codelab/CMakeLists.txt b/runtime/engine/codelab/CMakeLists.txt
new file mode 100644
index 00000000..13aa5efb
--- /dev/null
+++ b/runtime/engine/codelab/CMakeLists.txt
@@ -0,0 +1,6 @@
+cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
+
+if(ANDROID)
+else() #Unix
+    add_subdirectory(glog)
+endif()
\ No newline at end of file
diff --git a/speechx/speechx/codelab/README.md b/runtime/engine/codelab/README.md
similarity index 100%
rename from speechx/speechx/codelab/README.md
rename to runtime/engine/codelab/README.md
diff --git a/speechx/speechx/codelab/glog/CMakeLists.txt b/runtime/engine/codelab/glog/CMakeLists.txt
similarity index 67%
rename from speechx/speechx/codelab/glog/CMakeLists.txt
rename to runtime/engine/codelab/glog/CMakeLists.txt
index 08a98641..492e33c6 100644
--- a/speechx/speechx/codelab/glog/CMakeLists.txt
+++ b/runtime/engine/codelab/glog/CMakeLists.txt
@@ -1,8 +1,8 @@
 cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
 
 add_executable(glog_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_main.cc)
-target_link_libraries(glog_main glog)
+target_link_libraries(glog_main extern_glog)
 
 add_executable(glog_logtostderr_main ${CMAKE_CURRENT_SOURCE_DIR}/glog_logtostderr_main.cc)
-target_link_libraries(glog_logtostderr_main glog)
+target_link_libraries(glog_logtostderr_main extern_glog)
 
diff --git a/speechx/speechx/codelab/glog/README.md b/runtime/engine/codelab/glog/README.md
similarity index 100%
rename from speechx/speechx/codelab/glog/README.md
rename to runtime/engine/codelab/glog/README.md
diff --git a/speechx/speechx/codelab/glog/glog_logtostderr_main.cc b/runtime/engine/codelab/glog/glog_logtostderr_main.cc
similarity index 100%
rename from speechx/speechx/codelab/glog/glog_logtostderr_main.cc
rename to runtime/engine/codelab/glog/glog_logtostderr_main.cc
diff --git a/speechx/speechx/codelab/glog/glog_main.cc b/runtime/engine/codelab/glog/glog_main.cc
similarity index 100%
rename from speechx/speechx/codelab/glog/glog_main.cc
rename to runtime/engine/codelab/glog/glog_main.cc
diff --git a/runtime/engine/common/CMakeLists.txt b/runtime/engine/common/CMakeLists.txt
new file mode 100644
index 00000000..405479ae
--- /dev/null
+++ b/runtime/engine/common/CMakeLists.txt
@@ -0,0 +1,19 @@
+include_directories(
+${CMAKE_CURRENT_SOURCE_DIR}
+${CMAKE_CURRENT_SOURCE_DIR}/../
+)
+add_subdirectory(base)
+add_subdirectory(utils)
+add_subdirectory(matrix)
+
+include_directories( +${CMAKE_CURRENT_SOURCE_DIR}/frontend +) +add_subdirectory(frontend) + +add_library(common INTERFACE) +target_link_libraries(common INTERFACE base utils kaldi-matrix frontend) +install(TARGETS base DESTINATION lib) +install(TARGETS utils DESTINATION lib) +install(TARGETS kaldi-matrix DESTINATION lib) +install(TARGETS frontend DESTINATION lib) \ No newline at end of file diff --git a/runtime/engine/common/base/CMakeLists.txt b/runtime/engine/common/base/CMakeLists.txt new file mode 100644 index 00000000..b17131b5 --- /dev/null +++ b/runtime/engine/common/base/CMakeLists.txt @@ -0,0 +1,43 @@ + + +if(WITH_ASR) + add_compile_options(-DWITH_ASR) + set(PPS_FLAGS_LIB "fst/flags.h") +else() + set(PPS_FLAGS_LIB "gflags/gflags.h") +endif() + +if(ANDROID) + set(PPS_GLOG_LIB "base/log_impl.h") +else() #UNIX + if(WITH_ASR) + set(PPS_GLOG_LIB "fst/log.h") + else() + set(PPS_GLOG_LIB "glog/logging.h") + endif() +endif() + +configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/flags.h.in + ${CMAKE_CURRENT_SOURCE_DIR}/flags.h @ONLY + ) +message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/flags.h") + +configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/log.h.in + ${CMAKE_CURRENT_SOURCE_DIR}/log.h @ONLY + ) +message(STATUS "Generated ${CMAKE_CURRENT_SOURCE_DIR}/log.h") + + +if(ANDROID) + set(csrc + log_impl.cc + glog_utils.cc + ) + add_library(base ${csrc}) + target_link_libraries(base gflags) +else() # UNIX + set(csrc) + add_library(base INTERFACE) +endif() \ No newline at end of file diff --git a/speechx/speechx/base/basic_types.h b/runtime/engine/common/base/basic_types.h similarity index 100% rename from speechx/speechx/base/basic_types.h rename to runtime/engine/common/base/basic_types.h diff --git a/speechx/speechx/base/common.h b/runtime/engine/common/base/common.h similarity index 93% rename from speechx/speechx/base/common.h rename to runtime/engine/common/base/common.h index 97bff966..b31fc53e 100644 --- a/speechx/speechx/base/common.h +++ b/runtime/engine/common/base/common.h @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include #include @@ -48,4 +50,5 @@ #include "base/log.h" #include "base/macros.h" #include "utils/file_utils.h" -#include "utils/math.h" \ No newline at end of file +#include "utils/math.h" +#include "utils/timer.h" \ No newline at end of file diff --git a/runtime/engine/common/base/config.h b/runtime/engine/common/base/config.h new file mode 100644 index 00000000..c8eae5e2 --- /dev/null +++ b/runtime/engine/common/base/config.h @@ -0,0 +1,343 @@ +// Copyright (c) code is from +// https://blog.csdn.net/huixingshao/article/details/45969887. + +#include +#include +#include +#include +#include +using namespace std; + +#pragma once + +#ifdef _MSC_VER +#pragma region ParseIniFile +#endif + +/* + * \brief Generic configuration Class + * + */ +class Config { + // Data + protected: + std::string m_Delimiter; //!< separator between key and value + std::string m_Comment; //!< separator between value and comments + std::map + m_Contents; //!< extracted keys and values + + typedef std::map::iterator mapi; + typedef std::map::const_iterator mapci; + // Methods + public: + Config(std::string filename, + std::string delimiter = "=", + std::string comment = "#"); + Config(); + template + T Read(const std::string& in_key) const; //!< Search for key and read value + //! or optional default value, call + //! 
as read<T>
+    template <typename T>
+    T Read(const std::string& in_key, const T& in_value) const;
+    template <typename T>
+    bool ReadInto(T* out_var, const std::string& in_key) const;
+    template <typename T>
+    bool ReadInto(T* out_var,
+                  const std::string& in_key,
+                  const T& in_value) const;
+    bool FileExist(std::string filename);
+    void ReadFile(std::string filename,
+                  std::string delimiter = "=",
+                  std::string comment = "#");
+
+    // Check whether key exists in configuration
+    bool KeyExists(const std::string& in_key) const;
+
+    // Modify keys and values
+    template <typename T>
+    void Add(const std::string& in_key, const T& in_value);
+    void Remove(const std::string& in_key);
+
+    // Check or change configuration syntax
+    std::string GetDelimiter() const { return m_Delimiter; }
+    std::string GetComment() const { return m_Comment; }
+    std::string SetDelimiter(const std::string& in_s) {
+        std::string old = m_Delimiter;
+        m_Delimiter = in_s;
+        return old;
+    }
+    std::string SetComment(const std::string& in_s) {
+        std::string old = m_Comment;
+        m_Comment = in_s;
+        return old;
+    }
+
+    // Write or read configuration
+    friend std::ostream& operator<<(std::ostream& os, const Config& cf);
+    friend std::istream& operator>>(std::istream& is, Config& cf);
+
+  protected:
+    template <typename T>
+    static std::string T_as_string(const T& t);
+    template <typename T>
+    static T string_as_T(const std::string& s);
+    static void Trim(std::string* inout_s);
+
+
+    // Exception types
+  public:
+    struct File_not_found {
+        std::string filename;
+        explicit File_not_found(const std::string& filename_ = std::string())
+            : filename(filename_) {}
+    };
+    struct Key_not_found {  // thrown only by T read(key) variant of read()
+        std::string key;
+        explicit Key_not_found(const std::string& key_ = std::string())
+            : key(key_) {}
+    };
+};
+
+/* static */
+template <typename T>
+std::string Config::T_as_string(const T& t) {
+    // Convert from a T to a string
+    // Type T must support << operator
+    std::ostringstream ost;
+    ost << t;
+    return ost.str();
+}
+
+
+/* static */
+template <typename T>
+T Config::string_as_T(const std::string& s) {
+    // Convert from a string to a T
+    // Type T must support >> operator
+    T t;
+    std::istringstream ist(s);
+    ist >> t;
+    return t;
+}
+
+
+/* static */
+template <>
+inline std::string Config::string_as_T<std::string>(const std::string& s) {
+    // Convert from a string to a string
+    // In other words, do nothing
+    return s;
+}
+
+
+/* static */
+template <>
+inline bool Config::string_as_T<bool>(const std::string& s) {
+    // Convert from a string to a bool
+    // Interpret "false", "F", "no", "n", "0" as false
+    // Interpret "true", "T", "yes", "y", "1", "-1", or anything else as true
+    bool b = true;
+    std::string sup = s;
+    for (std::string::iterator p = sup.begin(); p != sup.end(); ++p)
+        *p = toupper(*p);  // make string all caps
+    if (sup == std::string("FALSE") || sup == std::string("F") ||
+        sup == std::string("NO") || sup == std::string("N") ||
+        sup == std::string("0") || sup == std::string("NONE"))
+        b = false;
+    return b;
+}
+
+
+template <typename T>
+T Config::Read(const std::string& key) const {
+    // Read the value corresponding to key
+    mapci p = m_Contents.find(key);
+    if (p == m_Contents.end()) throw Key_not_found(key);
+    return string_as_T<T>(p->second);
+}
+
+
+template <typename T>
+T Config::Read(const std::string& key, const T& value) const {
+    // Return the value corresponding to key or given default value
+    // if key is not found
+    mapci p = m_Contents.find(key);
+    if (p == m_Contents.end()) {
+        printf("%s = %s(default)\n", key.c_str(), T_as_string(value).c_str());
+        return value;
+    } else {
+        printf("%s = %s\n", key.c_str(),
+               T_as_string(p->second).c_str());
+        return string_as_T<T>(p->second);
+    }
+}
+
+
+template <typename T>
+bool Config::ReadInto(T* var, const std::string& key) const {
+    // Get the value corresponding to key and store in var
+    // Return true if key is found
+    // Otherwise leave var untouched
+    mapci p = m_Contents.find(key);
+    bool found = (p != m_Contents.end());
+    if (found) *var = string_as_T<T>(p->second);
+    return found;
+}
+
+
+template <typename T>
+bool Config::ReadInto(T* var, const std::string& key, const T& value) const {
+    // Get the value corresponding to key and store in var
+    // Return true if key is found
+    // Otherwise set var to given default
+    mapci p = m_Contents.find(key);
+    bool found = (p != m_Contents.end());
+    if (found)
+        *var = string_as_T<T>(p->second);
+    else
+        *var = value;
+    return found;
+}
+
+
+template <typename T>
+void Config::Add(const std::string& in_key, const T& value) {
+    // Add a key with given value
+    std::string v = T_as_string(value);
+    std::string key = in_key;
+    Trim(&key);
+    Trim(&v);
+    m_Contents[key] = v;
+    return;
+}
+
+Config::Config(string filename, string delimiter, string comment)
+    : m_Delimiter(delimiter), m_Comment(comment) {
+    // Construct a Config, getting keys and values from given file
+
+    std::ifstream in(filename.c_str());
+
+    if (!in) throw File_not_found(filename);
+
+    in >> (*this);
+}
+
+
+Config::Config() : m_Delimiter(string(1, '=')), m_Comment(string(1, '#')) {
+    // Construct a Config without a file; empty
+}
+
+
+bool Config::KeyExists(const string& key) const {
+    // Indicate whether key is found
+    mapci p = m_Contents.find(key);
+    return (p != m_Contents.end());
+}
+
+
+/* static */
+void Config::Trim(string* inout_s) {
+    // Remove leading and trailing whitespace
+    static const char whitespace[] = " \n\t\v\r\f";
+    inout_s->erase(0, inout_s->find_first_not_of(whitespace));
+    inout_s->erase(inout_s->find_last_not_of(whitespace) + 1U);
+}
+
+
+std::ostream& operator<<(std::ostream& os, const Config& cf) {
+    // Save a Config to os
+    for (Config::mapci p = cf.m_Contents.begin(); p != cf.m_Contents.end();
+         ++p) {
+        os << p->first << " " << cf.m_Delimiter << " ";
+        os << p->second << std::endl;
+    }
+    return os;
+}
+
+void Config::Remove(const string& key) {
+    // Remove key and its value, if present
+    m_Contents.erase(key);
+    return;
+}
+
+std::istream& operator>>(std::istream& is, Config& cf) {
+    // Load a Config from is
+    // Read in keys and values, keeping internal whitespace
+    typedef string::size_type pos;
+    const string& delim = cf.m_Delimiter;  // separator
+    const string& comm = cf.m_Comment;     // comment
+    const pos skip = delim.length();       // length of separator
+
+    string nextline = "";  // might need to read ahead to see where value ends
+
+    while (is || nextline.length() > 0) {
+        // Read an entire line at a time
+        string line;
+        if (nextline.length() > 0) {
+            line = nextline;  // we read ahead; use it now
+            nextline = "";
+        } else {
+            std::getline(is, line);
+        }
+
+        // Ignore comments
+        line = line.substr(0, line.find(comm));
+
+        // Parse the line if it contains a delimiter
+        pos delimPos = line.find(delim);
+        if (delimPos < string::npos) {
+            // Extract the key
+            string key = line.substr(0, delimPos);
+            line.replace(0, delimPos + skip, "");
+
+            // See if value continues on the next line
+            // Stop at blank line, next line with a key, end of stream,
+            // or end of file sentry
+            bool terminate = false;
+            while (!terminate && is) {
+                std::getline(is, nextline);
+                terminate = true;
+
+                string nlcopy = nextline;
+                Config::Trim(&nlcopy);
+                if (nlcopy == "") continue;
+
+                nextline = nextline.substr(0, nextline.find(comm));
+                if (nextline.find(delim) != string::npos) continue;
+
+                nlcopy = nextline;
+                Config::Trim(&nlcopy);
+                if (nlcopy != "") line += "\n";
+                line += nextline;
+                terminate = false;
+            }
+
+            // Store key and value
+            Config::Trim(&key);
+            Config::Trim(&line);
+            cf.m_Contents[key] = line;  // overwrites if key is repeated
+        }
+    }
+
+    return is;
+}
+bool Config::FileExist(std::string filename) {
+    bool exist = false;
+    std::ifstream in(filename.c_str());
+    if (in) exist = true;
+    return exist;
+}
+
+void Config::ReadFile(string filename, string delimiter, string comment) {
+    m_Delimiter = delimiter;
+    m_Comment = comment;
+    std::ifstream in(filename.c_str());
+
+    if (!in) throw File_not_found(filename);
+
+    in >> (*this);
+}
+
+#ifdef _MSC_VER
+#pragma endregion ParseIniFile
+#endif
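A usage sketch for the Config reader completed above; the file name and keys are invented for illustration:

```cpp
#include "base/config.h"

void LoadClsConf() {
    Config conf("cls.conf");  // throws Config::File_not_found if absent
    // Read<T>(key) throws Config::Key_not_found; the two-argument overload
    // falls back to the supplied default instead.
    std::string model = conf.Read<std::string>("model_path");
    int threads = conf.Read<int>("num_cpu_thread", 1);
    bool normalize = false;
    conf.ReadInto(&normalize, "wav_normal", false);  // default on miss
}
```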
diff --git a/speechx/speechx/base/log.h b/runtime/engine/common/base/flags.h.in
similarity index 95%
rename from speechx/speechx/base/log.h
rename to runtime/engine/common/base/flags.h.in
index c613b98c..fd265abc 100644
--- a/speechx/speechx/base/log.h
+++ b/runtime/engine/common/base/flags.h.in
@@ -14,4 +14,4 @@
 
 #pragma once
 
-#include "fst/log.h"
+#include "@PPS_FLAGS_LIB@"
diff --git a/runtime/engine/common/base/glog_utils.cc b/runtime/engine/common/base/glog_utils.cc
new file mode 100644
index 00000000..4ab3c251
--- /dev/null
+++ b/runtime/engine/common/base/glog_utils.cc
@@ -0,0 +1,12 @@
+
+#include "base/glog_utils.h"
+
+namespace google {
+void InitGoogleLogging(const char* name) {
+    LOG(INFO) << "dummy InitGoogleLogging.";
+}
+
+void InstallFailureSignalHandler() {
+    LOG(INFO) << "dummy InstallFailureSignalHandler.";
+}
+} // namespace google
diff --git a/runtime/engine/common/base/glog_utils.h b/runtime/engine/common/base/glog_utils.h
new file mode 100644
index 00000000..9cffcafb
--- /dev/null
+++ b/runtime/engine/common/base/glog_utils.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "base/common.h"
+
+namespace google {
+void InitGoogleLogging(const char* name);
+
+void InstallFailureSignalHandler();
+} // namespace google
\ No newline at end of file
diff --git a/speechx/speechx/base/flags.h b/runtime/engine/common/base/log.h.in
similarity index 96%
rename from speechx/speechx/base/flags.h
rename to runtime/engine/common/base/log.h.in
index 41df0d45..5d121add 100644
--- a/speechx/speechx/base/flags.h
+++ b/runtime/engine/common/base/log.h.in
@@ -14,4 +14,4 @@
 
 #pragma once
 
-#include "fst/flags.h"
+#include "@PPS_GLOG_LIB@"
diff --git a/runtime/engine/common/base/log_impl.cc b/runtime/engine/common/base/log_impl.cc
new file mode 100644
index 00000000..d8295590
--- /dev/null
+++ b/runtime/engine/common/base/log_impl.cc
@@ -0,0 +1,105 @@
+#include "base/log.h"
+
+DEFINE_int32(logtostderr, 0, "logging to stderr");
+
+namespace ppspeech {
+
+static char __progname[] = "paddlespeech";
+
+namespace log {
+
+std::mutex LogMessage::lock_;
+std::string LogMessage::s_debug_logfile_("");
+std::string LogMessage::s_info_logfile_("");
+std::string LogMessage::s_warning_logfile_("");
+std::string LogMessage::s_error_logfile_("");
+std::string LogMessage::s_fatal_logfile_("");
+
+void LogMessage::get_curr_proc_info(std::string* pid, std::string* proc_name) {
+    std::stringstream ss;
+    ss << getpid();
+    ss >> *pid;
+    *proc_name = ::ppspeech::__progname;
+}
+
+LogMessage::LogMessage(const char* file,
+                       int line,
+                       Severity level,
+                       bool verbose,
+                       bool out_to_file /* = false */)
+    : level_(level), verbose_(verbose), out_to_file_(out_to_file) {
+    if (FLAGS_logtostderr == 0) {
stream_ = static_cast(&std::cout); + } else if (FLAGS_logtostderr == 1) { + stream_ = static_cast(&std::cerr); + } else if (out_to_file_) { + // logfile + lock_.lock(); + init(file, line); + } +} + +LogMessage::~LogMessage() { + stream() << std::endl; + + if (out_to_file_) { + lock_.unlock(); + } + + if (verbose_ && level_ == FATAL) { + std::abort(); + } +} + +std::ostream* LogMessage::nullstream() { + thread_local static std::ofstream os; + thread_local static bool flag_set = false; + if (!flag_set) { + os.setstate(std::ios_base::badbit); + flag_set = true; + } + return &os; +} + +void LogMessage::init(const char* file, int line) { + time_t t = time(0); + char tmp[100]; + strftime(tmp, sizeof(tmp), "%Y%m%d-%H%M%S", localtime(&t)); + + if (s_info_logfile_.empty()) { + std::string pid; + std::string proc_name; + get_curr_proc_info(&pid, &proc_name); + + s_debug_logfile_ = + std::string("log." + proc_name + ".log.DEBUG." + tmp + "." + pid); + s_info_logfile_ = + std::string("log." + proc_name + ".log.INFO." + tmp + "." + pid); + s_warning_logfile_ = + std::string("log." + proc_name + ".log.WARNING." + tmp + "." + pid); + s_error_logfile_ = + std::string("log." + proc_name + ".log.ERROR." + tmp + "." + pid); + s_fatal_logfile_ = + std::string("log." + proc_name + ".log.FATAL." + tmp + "." + pid); + } + + thread_local static std::ofstream ofs; + if (level_ == DEBUG) { + ofs.open(s_debug_logfile_.c_str(), std::ios::out | std::ios::app); + } else if (level_ == INFO) { + ofs.open(s_info_logfile_.c_str(), std::ios::out | std::ios::app); + } else if (level_ == WARNING) { + ofs.open(s_warning_logfile_.c_str(), std::ios::out | std::ios::app); + } else if (level_ == ERROR) { + ofs.open(s_error_logfile_.c_str(), std::ios::out | std::ios::app); + } else { + ofs.open(s_fatal_logfile_.c_str(), std::ios::out | std::ios::app); + } + + stream_ = &ofs; + + stream() << tmp << " " << file << " line " << line << "; "; + stream() << std::flush; +} +} // namespace log +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/common/base/log_impl.h b/runtime/engine/common/base/log_impl.h new file mode 100644 index 00000000..fd6cce19 --- /dev/null +++ b/runtime/engine/common/base/log_impl.h @@ -0,0 +1,173 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
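To summarize the sink selection implemented above (the file names are derived by init(); the timestamp and pid below are illustrative):

```cpp
// FLAGS_logtostderr == 0 -> std::cout
// FLAGS_logtostderr == 1 -> std::cerr
// out_to_file_ == true   -> one file per severity, opened by init(), e.g.
//                           log.paddlespeech.log.INFO.20230406-153000.12345
```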
+ +// modified from https://github.com/Dounm/dlog +// modified form +// https://android.googlesource.com/platform/art/+/806defa/src/logging.h + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "base/common.h" +#include "base/macros.h" +#ifndef WITH_GLOG +#include "base/glog_utils.h" +#endif + +DECLARE_int32(logtostderr); + +namespace ppspeech { + +namespace log { + +enum Severity { + DEBUG, + INFO, + WARNING, + ERROR, + FATAL, + NUM_SEVERITIES, +}; + +class LogMessage { + public: + static void get_curr_proc_info(std::string* pid, std::string* proc_name); + + LogMessage(const char* file, + int line, + Severity level, + bool verbose, + bool out_to_file = false); + + ~LogMessage(); + + std::ostream& stream() { return verbose_ ? *stream_ : *nullstream(); } + + private: + void init(const char* file, int line); + std::ostream* nullstream(); + + private: + std::ostream* stream_; + std::ostream* null_stream_; + Severity level_; + bool verbose_; + bool out_to_file_; + + static std::mutex lock_; // stream write lock + static std::string s_debug_logfile_; + static std::string s_info_logfile_; + static std::string s_warning_logfile_; + static std::string s_error_logfile_; + static std::string s_fatal_logfile_; + + DISALLOW_COPY_AND_ASSIGN(LogMessage); +}; + + +} // namespace log + +} // namespace ppspeech + + +#ifndef PPS_DEBUG +#define DLOG_INFO \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, false) +#define DLOG_WARNING \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, false) +#define DLOG_ERROR \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, false) +#define DLOG_FATAL \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, false) +#else +#define DLOG_INFO \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true) +#define DLOG_WARNING \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, true) +#define DLOG_ERROR \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true) +#define DLOG_FATAL \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true) +#endif + + +#define LOG_INFO \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::INFO, true) +#define LOG_WARNING \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::WARNING, true) +#define LOG_ERROR \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::ERROR, true) +#define LOG_FATAL \ + ppspeech::log::LogMessage(__FILE__, __LINE__, ppspeech::log::FATAL, true) + + +#define LOG_0 LOG_DEBUG +#define LOG_1 LOG_INFO +#define LOG_2 LOG_WARNING +#define LOG_3 LOG_ERROR +#define LOG_4 LOG_FATAL + +#define LOG(level) LOG_##level.stream() + +#define DLOG(level) DLOG_##level.stream() + +#define VLOG(verboselevel) LOG(verboselevel) + +#define CHECK(exp) \ + ppspeech::log::LogMessage( \ + __FILE__, __LINE__, ppspeech::log::FATAL, !(exp)) \ + .stream() \ + << "Check Failed: " #exp + +#define CHECK_EQ(x, y) CHECK((x) == (y)) +#define CHECK_NE(x, y) CHECK((x) != (y)) +#define CHECK_LE(x, y) CHECK((x) <= (y)) +#define CHECK_LT(x, y) CHECK((x) < (y)) +#define CHECK_GE(x, y) CHECK((x) >= (y)) +#define CHECK_GT(x, y) CHECK((x) > (y)) +#ifdef PPS_DEBUG +#define DCHECK(x) CHECK(x) +#define DCHECK_EQ(x, y) CHECK_EQ(x, y) +#define DCHECK_NE(x, y) CHECK_NE(x, y) +#define DCHECK_LE(x, y) CHECK_LE(x, y) +#define DCHECK_LT(x, y) CHECK_LT(x, y) +#define DCHECK_GE(x, y) CHECK_GE(x, y) +#define DCHECK_GT(x, y) CHECK_GT(x, y) 
+#else +#define DCHECK(condition) \ + while (false) CHECK(condition) +#define DCHECK_EQ(val1, val2) \ + while (false) CHECK_EQ(val1, val2) +#define DCHECK_NE(val1, val2) \ + while (false) CHECK_NE(val1, val2) +#define DCHECK_LE(val1, val2) \ + while (false) CHECK_LE(val1, val2) +#define DCHECK_LT(val1, val2) \ + while (false) CHECK_LT(val1, val2) +#define DCHECK_GE(val1, val2) \ + while (false) CHECK_GE(val1, val2) +#define DCHECK_GT(val1, val2) \ + while (false) CHECK_GT(val1, val2) +#define DCHECK_STREQ(str1, str2) \ + while (false) CHECK_STREQ(str1, str2) +#endif \ No newline at end of file diff --git a/speechx/speechx/base/macros.h b/runtime/engine/common/base/macros.h similarity index 100% rename from speechx/speechx/base/macros.h rename to runtime/engine/common/base/macros.h index db989812..e60baf55 100644 --- a/speechx/speechx/base/macros.h +++ b/runtime/engine/common/base/macros.h @@ -17,14 +17,14 @@ #include #include -namespace ppspeech { - #ifndef DISALLOW_COPY_AND_ASSIGN #define DISALLOW_COPY_AND_ASSIGN(TypeName) \ TypeName(const TypeName&) = delete; \ void operator=(const TypeName&) = delete #endif +namespace ppspeech { + // kSpaceSymbol in UTF-8 is: ▁ const char kSpaceSymbo[] = "\xe2\x96\x81"; diff --git a/runtime/engine/common/base/safe_queue.h b/runtime/engine/common/base/safe_queue.h new file mode 100644 index 00000000..25a012af --- /dev/null +++ b/runtime/engine/common/base/safe_queue.h @@ -0,0 +1,71 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
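Stepping back from the file layout, a behavior sketch for the logging and check macros defined above (illustrative, not part of the patch):

```cpp
#include "base/log.h"

void Demo(int nframes) {
    LOG(INFO) << "frames: " << nframes;  // always active; goes to std::cout
                                         // when FLAGS_logtostderr == 0
    CHECK_GT(nframes, 0);                // streams "Check Failed: ..." and
                                         // aborts when the condition is false
    DLOG(INFO) << "per-frame detail";    // discarded unless PPS_DEBUG is defined
    DCHECK_EQ(nframes % 2, 0);           // compiled away without PPS_DEBUG
}
```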
+
+#pragma once
+
+#include "base/common.h"
+
+namespace ppspeech {
+
+template <typename T>
+class SafeQueue {
+  public:
+    explicit SafeQueue(size_t capacity = 0);
+    void push_back(const T& in);
+    bool pop(T* out);
+    bool empty() const { return buffer_.empty(); }
+    size_t size() const { return buffer_.size(); }
+    void clear();
+
+
+  private:
+    std::mutex mutex_;
+    std::condition_variable condition_;
+    std::deque<T> buffer_;
+    size_t capacity_;
+};
+
+template <typename T>
+SafeQueue<T>::SafeQueue(size_t capacity) : capacity_(capacity) {}
+
+template <typename T>
+void SafeQueue<T>::push_back(const T& in) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (capacity_ > 0 && buffer_.size() == capacity_) {
+        // wait until a consumer has made room; the predicate also guards
+        // against spurious wakeups
+        condition_.wait(lock, [this] { return buffer_.size() < capacity_; });
+    }
+
+    buffer_.push_back(in);
+    condition_.notify_one();
+}
+
+template <typename T>
+bool SafeQueue<T>::pop(T* out) {
+    // check emptiness only while holding the lock
+    std::unique_lock<std::mutex> lock(mutex_);
+    if (buffer_.empty()) {
+        return false;
+    }
+    *out = std::move(buffer_.front());
+    buffer_.pop_front();
+    condition_.notify_one();
+    return true;
+}
+
+template <typename T>
+void SafeQueue<T>::clear() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    buffer_.clear();
+    condition_.notify_one();
+}
+} // namespace ppspeech
diff --git a/speechx/speechx/frontend/text/CMakeLists.txt b/runtime/engine/common/base/safe_queue_inl.h
similarity index 100%
rename from speechx/speechx/frontend/text/CMakeLists.txt
rename to runtime/engine/common/base/safe_queue_inl.h
diff --git a/speechx/speechx/base/thread_pool.h b/runtime/engine/common/base/thread_pool.h
similarity index 100%
rename from speechx/speechx/base/thread_pool.h
rename to runtime/engine/common/base/thread_pool.h
diff --git a/runtime/engine/common/frontend/CMakeLists.txt b/runtime/engine/common/frontend/CMakeLists.txt
new file mode 100644
index 00000000..0b95b650
--- /dev/null
+++ b/runtime/engine/common/frontend/CMakeLists.txt
@@ -0,0 +1,31 @@
+add_library(kaldi-native-fbank-core
+    feature-fbank.cc
+    feature-functions.cc
+    feature-window.cc
+    fftsg.c
+    mel-computations.cc
+    rfft.cc
+)
+target_link_libraries(kaldi-native-fbank-core PUBLIC utils base)
+target_compile_options(kaldi-native-fbank-core PUBLIC "-fPIC")
+
+add_library(frontend STATIC
+    cmvn.cc
+    audio_cache.cc
+    feature_cache.cc
+    feature_pipeline.cc
+    assembler.cc
+    wave-reader.cc
+)
+target_link_libraries(frontend PUBLIC kaldi-native-fbank-core utils base)
+
+set(BINS
+    compute_fbank_main
+)
+
+foreach(bin_name IN LISTS BINS)
+    add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
+    target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
+    # https://github.com/Kitware/CMake/blob/v3.1.0/Modules/FindThreads.cmake#L207
+    target_link_libraries(${bin_name} PUBLIC frontend base utils kaldi-util libgflags_nothreads.so Threads::Threads extern_glog)
+endforeach()
diff --git a/speechx/speechx/frontend/audio/assembler.cc b/runtime/engine/common/frontend/assembler.cc
similarity index 75%
rename from speechx/speechx/frontend/audio/assembler.cc
rename to runtime/engine/common/frontend/assembler.cc
index 9d5fc403..ba46e1ca 100644
--- a/speechx/speechx/frontend/audio/assembler.cc
+++ b/runtime/engine/common/frontend/assembler.cc
@@ -12,14 +12,13 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
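A usage sketch for the SafeQueue above; with a non-zero capacity push_back() blocks until a consumer makes room, while pop() returns false as soon as the queue is empty (hypothetical producer/consumer code, not part of the patch):

```cpp
#include "base/safe_queue.h"

void Produce(ppspeech::SafeQueue<int>* q) {
    for (int i = 0; i < 8; ++i) {
        q->push_back(i);  // blocks while the queue is full
    }
}

void Drain(ppspeech::SafeQueue<int>* q) {
    int v = 0;
    while (q->pop(&v)) {
        // consume v; the loop ends once the queue has been emptied
    }
}

// ppspeech::SafeQueue<int> q(4);  // capacity 4; 0 (the default) is unbounded
```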
-#include "frontend/audio/assembler.h" +#include "frontend/assembler.h" namespace ppspeech { using kaldi::BaseFloat; -using kaldi::Vector; -using kaldi::VectorBase; using std::unique_ptr; +using std::vector; Assembler::Assembler(AssemblerOptions opts, unique_ptr base_extractor) { @@ -33,13 +32,13 @@ Assembler::Assembler(AssemblerOptions opts, dim_ = base_extractor_->Dim(); } -void Assembler::Accept(const kaldi::VectorBase& inputs) { +void Assembler::Accept(const std::vector& inputs) { // read inputs base_extractor_->Accept(inputs); } // pop feature chunk -bool Assembler::Read(kaldi::Vector* feats) { +bool Assembler::Read(std::vector* feats) { kaldi::Timer timer; bool result = Compute(feats); VLOG(1) << "Assembler::Read cost: " << timer.Elapsed() << " sec."; @@ -47,40 +46,37 @@ bool Assembler::Read(kaldi::Vector* feats) { } // read frame by frame from base_feature_extractor_ into cache_ -bool Assembler::Compute(Vector* feats) { +bool Assembler::Compute(vector* feats) { // compute and feed frame by frame while (feature_cache_.size() < frame_chunk_size_) { - Vector feature; + vector feature; bool result = base_extractor_->Read(&feature); - if (result == false || feature.Dim() == 0) { - VLOG(3) << "result: " << result - << " feature dim: " << feature.Dim(); + if (result == false || feature.size() == 0) { + VLOG(1) << "result: " << result + << " feature dim: " << feature.size(); if (IsFinished() == false) { - VLOG(3) << "finished reading feature. cache size: " + VLOG(1) << "finished reading feature. cache size: " << feature_cache_.size(); return false; } else { - VLOG(3) << "break"; + VLOG(1) << "break"; break; } } - - CHECK(feature.Dim() == dim_); feature_cache_.push(feature); - nframes_ += 1; - VLOG(3) << "nframes: " << nframes_; + VLOG(1) << "nframes: " << nframes_; } if (feature_cache_.size() < receptive_filed_length_) { - VLOG(3) << "feature_cache less than receptive_filed_lenght. " + VLOG(3) << "feature_cache less than receptive_filed_length. 
" << feature_cache_.size() << ": " << receptive_filed_length_; return false; } if (fill_zero_) { while (feature_cache_.size() < frame_chunk_size_) { - Vector feature(dim_, kaldi::kSetZero); + vector feature(dim_, kaldi::kSetZero); nframes_ += 1; feature_cache_.push(feature); } @@ -88,16 +84,17 @@ bool Assembler::Compute(Vector* feats) { int32 this_chunk_size = std::min(static_cast(feature_cache_.size()), frame_chunk_size_); - feats->Resize(dim_ * this_chunk_size); + feats->resize(dim_ * this_chunk_size); VLOG(3) << "read " << this_chunk_size << " feat."; int32 counter = 0; while (counter < this_chunk_size) { - Vector& val = feature_cache_.front(); - CHECK(val.Dim() == dim_) << val.Dim(); + vector& val = feature_cache_.front(); + CHECK(val.size() == dim_) << val.size(); int32 start = counter * dim_; - feats->Range(start, dim_).CopyFromVec(val); + std::memcpy( + feats->data() + start, val.data(), val.size() * sizeof(BaseFloat)); if (this_chunk_size - counter <= cache_size_) { feature_cache_.push(val); @@ -115,7 +112,7 @@ bool Assembler::Compute(Vector* feats) { void Assembler::Reset() { - std::queue> empty; + std::queue> empty; std::swap(feature_cache_, empty); nframes_ = 0; base_extractor_->Reset(); diff --git a/speechx/speechx/frontend/audio/assembler.h b/runtime/engine/common/frontend/assembler.h similarity index 86% rename from speechx/speechx/frontend/audio/assembler.h rename to runtime/engine/common/frontend/assembler.h index 72e6f635..9ec28053 100644 --- a/speechx/speechx/frontend/audio/assembler.h +++ b/runtime/engine/common/frontend/assembler.h @@ -15,7 +15,7 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" namespace ppspeech { @@ -36,10 +36,10 @@ class Assembler : public FrontendInterface { std::unique_ptr base_extractor = NULL); // Feed feats or waves - void Accept(const kaldi::VectorBase& inputs) override; + void Accept(const std::vector& inputs) override; // feats size = num_frames * feat_dim - bool Read(kaldi::Vector* feats) override; + bool Read(std::vector* feats) override; // feat dim size_t Dim() const override { return dim_; } @@ -51,7 +51,7 @@ class Assembler : public FrontendInterface { void Reset() override; private: - bool Compute(kaldi::Vector* feats); + bool Compute(std::vector* feats); bool fill_zero_{false}; @@ -60,7 +60,7 @@ class Assembler : public FrontendInterface { int32 frame_chunk_stride_; // stride int32 cache_size_; // window - stride int32 receptive_filed_length_; - std::queue> feature_cache_; + std::queue> feature_cache_; std::unique_ptr base_extractor_; int32 nframes_; // num frame computed diff --git a/speechx/speechx/frontend/audio/audio_cache.cc b/runtime/engine/common/frontend/audio_cache.cc similarity index 63% rename from speechx/speechx/frontend/audio/audio_cache.cc rename to runtime/engine/common/frontend/audio_cache.cc index c6a91f4b..7ff1c4c4 100644 --- a/speechx/speechx/frontend/audio/audio_cache.cc +++ b/runtime/engine/common/frontend/audio_cache.cc @@ -12,15 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "frontend/audio/audio_cache.h" +#include "frontend/audio_cache.h" #include "kaldi/base/timer.h" namespace ppspeech { using kaldi::BaseFloat; -using kaldi::Vector; -using kaldi::VectorBase; +using std::vector; AudioCache::AudioCache(int buffer_size, bool to_float32) : finished_(false), @@ -37,53 +36,39 @@ BaseFloat AudioCache::Convert2PCM32(BaseFloat val) { return val * (1. 
/ std::pow(2.0, 15)); } -void AudioCache::Accept(const VectorBase& waves) { +void AudioCache::Accept(const vector& waves) { kaldi::Timer timer; std::unique_lock lock(mutex_); - while (size_ + waves.Dim() > ring_buffer_.size()) { + while (size_ + waves.size() > ring_buffer_.size()) { ready_feed_condition_.wait(lock); } - for (size_t idx = 0; idx < waves.Dim(); ++idx) { + for (size_t idx = 0; idx < waves.size(); ++idx) { int32 buffer_idx = (idx + offset_ + size_) % ring_buffer_.size(); - ring_buffer_[buffer_idx] = waves(idx); - if (to_float32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves(idx)); + ring_buffer_[buffer_idx] = waves[idx]; + if (to_float32_) ring_buffer_[buffer_idx] = Convert2PCM32(waves[idx]); } - size_ += waves.Dim(); + size_ += waves.size(); VLOG(1) << "AudioCache::Accept cost: " << timer.Elapsed() << " sec. " - << waves.Dim() << " samples."; + << waves.size() << " samples."; } -bool AudioCache::Read(Vector* waves) { +bool AudioCache::Read(vector* waves) { kaldi::Timer timer; - size_t chunk_size = waves->Dim(); + size_t chunk_size = waves->size(); std::unique_lock lock(mutex_); - while (chunk_size > size_) { - // when audio is empty and no more data feed - // ready_read_condition will block in dead lock, - // so replace with timeout_ - // ready_read_condition_.wait(lock); - int32 elapsed = static_cast(timer.Elapsed() * 1000); - if (elapsed > timeout_) { - if (finished_ == true) { - // read last chunk data - break; - } - if (chunk_size > size_) { - return false; - } - } - usleep(100); // sleep 0.1 ms - } - - // read last chunk data if (chunk_size > size_) { - chunk_size = size_; - waves->Resize(chunk_size); + if (finished_ == false) { + return false; + } else { + // read last chunk data + chunk_size = size_; + waves->resize(chunk_size); + } } for (size_t idx = 0; idx < chunk_size; ++idx) { int buff_idx = (offset_ + idx) % ring_buffer_.size(); - waves->Data()[idx] = ring_buffer_[buff_idx]; + waves->at(idx) = ring_buffer_[buff_idx]; } size_ -= chunk_size; offset_ = (offset_ + chunk_size) % ring_buffer_.size(); diff --git a/speechx/speechx/frontend/audio/audio_cache.h b/runtime/engine/common/frontend/audio_cache.h similarity index 89% rename from speechx/speechx/frontend/audio/audio_cache.h rename to runtime/engine/common/frontend/audio_cache.h index 4708a6e0..fdc4fdf4 100644 --- a/speechx/speechx/frontend/audio/audio_cache.h +++ b/runtime/engine/common/frontend/audio_cache.h @@ -16,7 +16,7 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" namespace ppspeech { @@ -26,9 +26,9 @@ class AudioCache : public FrontendInterface { explicit AudioCache(int buffer_size = 1000 * kint16max, bool to_float32 = false); - virtual void Accept(const kaldi::VectorBase& waves); + virtual void Accept(const std::vector& waves); - virtual bool Read(kaldi::Vector* waves); + virtual bool Read(std::vector* waves); // the audio dim is 1, one sample, which is useless, // so we return size_(cache samples) instead. @@ -39,7 +39,7 @@ class AudioCache : public FrontendInterface { finished_ = true; } - virtual bool IsFinished() const { return finished_; } + virtual bool IsFinished() const { return finished_ && (size_ == 0); } void Reset() override { offset_ = 0; diff --git a/runtime/engine/common/frontend/cmvn.cc b/runtime/engine/common/frontend/cmvn.cc new file mode 100644 index 00000000..0f110820 --- /dev/null +++ b/runtime/engine/common/frontend/cmvn.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "frontend/cmvn.h"
+
+#include "utils/file_utils.h"
+#include "utils/picojson.h"
+
+namespace ppspeech {
+
+using kaldi::BaseFloat;
+using std::unique_ptr;
+using std::vector;
+
+
+CMVN::CMVN(std::string cmvn_file, unique_ptr<FrontendInterface> base_extractor)
+    : var_norm_(true) {
+    CHECK_NE(cmvn_file, "");
+    base_extractor_ = std::move(base_extractor);
+    ReadCMVNFromJson(cmvn_file);
+    dim_ = mean_stats_.size() - 1;
+}
+
+void CMVN::ReadCMVNFromJson(std::string cmvn_file) {
+    std::string json_str = ppspeech::ReadFile2String(cmvn_file);
+    picojson::value value;
+    std::string err;
+    const char* json_end = picojson::parse(
+        value, json_str.c_str(), json_str.c_str() + json_str.size(), &err);
+    if (!value.is<picojson::object>()) {
+        LOG(ERROR) << "Input json file format error.";
+    }
+    const picojson::value::array& mean_stat =
+        value.get("mean_stat").get<picojson::array>();
+    for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
+        mean_stats_.push_back((*it).get<double>());
+    }
+
+    const picojson::value::array& var_stat =
+        value.get("var_stat").get<picojson::array>();
+    for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
+        var_stats_.push_back((*it).get<double>());
+    }
+
+    kaldi::int32 frame_num = value.get("frame_num").get<double>();
+    LOG(INFO) << "nframe: " << frame_num;
+    mean_stats_.push_back(frame_num);
+    var_stats_.push_back(0);
+}
+
+void CMVN::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
+    // feed waves/feats to compute feature
+    base_extractor_->Accept(inputs);
+    return;
+}
+
+bool CMVN::Read(std::vector<kaldi::BaseFloat>* feats) {
+    // compute feature
+    if (base_extractor_->Read(feats) == false || feats->size() == 0) {
+        return false;
+    }
+
+    // apply cmvn
+    kaldi::Timer timer;
+    Compute(feats);
+    VLOG(1) << "CMVN::Read cost: " << timer.Elapsed() << " sec.";
+    return true;
+}
+
+// feats contain num_frames feature.
+void CMVN::Compute(vector<BaseFloat>* feats) const {
+    KALDI_ASSERT(feats != NULL);
+
+    if (feats->size() % dim_ != 0) {
+        LOG(ERROR) << "Dim mismatch: cmvn " << mean_stats_.size() << ','
+                   << var_stats_.size() - 1 << ", feats " << feats->size()
+                   << 'x';
+    }
+    if (var_stats_.size() == 0 && var_norm_) {
+        LOG(ERROR)
+            << "You requested variance normalization but no variance stats_ "
+            << "are supplied.";
+    }
+
+    double count = mean_stats_[dim_];
+    // Do not change the threshold of 1.0 here: in the balanced-cmvn code, when
+    // computing an offset and representing it as stats_, we use a count of one.
+    if (count < 1.0)
+        LOG(ERROR) << "Insufficient stats_ for cepstral mean and variance "
+                      "normalization: "
+                   << "count = " << count;
+
+    if (!var_norm_) {
+        vector<double> mean_stats(mean_stats_);
+        for (size_t i = 0; i < mean_stats.size(); ++i) {
+            mean_stats[i] /= count;
+        }
+        vector<double> mean_stats_apply(feats->size());
+        // fill the data of mean_stats into mean_stats_apply, whose dim_ equals
+        // the dim_ of the feature;
+        // the dim_ of feats = dim_ * num_frames;
+        for (int32 idx = 0; idx < feats->size() / dim_; ++idx) {
+            std::memcpy(mean_stats_apply.data() + dim_ * idx,
+                        mean_stats.data(),
+                        dim_ * sizeof(double));
+        }
+        // subtract the per-dimension mean from every frame
+        for (size_t idx = 0; idx < feats->size(); ++idx) {
+            feats->at(idx) -= mean_stats_apply[idx];
+        }
+        return;
+    }
+    // norm(0, d) = mean offset;
+    // norm(1, d) = scale, e.g. x(d) <-- x(d)*norm(1, d) + norm(0, d).
+    vector<BaseFloat> norm0(feats->size());
+    vector<BaseFloat> norm1(feats->size());
+    for (int32 d = 0; d < dim_; d++) {
+        double mean, offset, scale;
+        mean = mean_stats_[d] / count;
+        double var = (var_stats_[d] / count) - mean * mean, floor = 1.0e-20;
+        if (var < floor) {
+            LOG(WARNING) << "Flooring cepstral variance from " << var << " to "
+                         << floor;
+            var = floor;
+        }
+        scale = 1.0 / sqrt(var);
+        if (scale != scale || 1 / scale == 0.0)
+            LOG(ERROR)
+                << "NaN or infinity in cepstral mean/variance computation";
+        offset = -(mean * scale);
+        for (int32 d_skip = d; d_skip < feats->size();) {
+            norm0[d_skip] = offset;
+            norm1[d_skip] = scale;
+            d_skip = d_skip + dim_;
+        }
+    }
+    // Apply the normalization.
+    for (size_t idx = 0; idx < feats->size(); ++idx) {
+        feats->at(idx) *= norm1[idx];
+    }
+
+    for (size_t idx = 0; idx < feats->size(); ++idx) {
+        feats->at(idx) += norm0[idx];
+    }
+}
+
+} // namespace ppspeech
diff --git a/speechx/speechx/frontend/audio/cmvn.h b/runtime/engine/common/frontend/cmvn.h
similarity index 77%
rename from speechx/speechx/frontend/audio/cmvn.h
rename to runtime/engine/common/frontend/cmvn.h
index 50ef5649..c515b6ae 100644
--- a/speechx/speechx/frontend/audio/cmvn.h
+++ b/runtime/engine/common/frontend/cmvn.h
@@ -15,8 +15,7 @@
 #pragma once
 
 #include "base/common.h"
-#include "frontend/audio/frontend_itf.h"
-#include "kaldi/matrix/kaldi-matrix.h"
+#include "frontend/frontend_itf.h"
 #include "kaldi/util/options-itf.h"
 
 namespace ppspeech {
@@ -25,11 +24,11 @@ class CMVN : public FrontendInterface {
   public:
     explicit CMVN(std::string cmvn_file,
                   std::unique_ptr<FrontendInterface> base_extractor);
-    virtual void Accept(const kaldi::VectorBase<kaldi::BaseFloat>& inputs);
+    virtual void Accept(const std::vector<kaldi::BaseFloat>& inputs);
 
     // the length of feats = feature_row * feature_dim,
     // the Matrix is squashed into Vector
-    virtual bool Read(kaldi::Vector<kaldi::BaseFloat>* feats);
+    virtual bool Read(std::vector<kaldi::BaseFloat>* feats);
     // the dim_ is the feature dim.
virtual size_t Dim() const { return dim_; } virtual void SetFinished() { base_extractor_->SetFinished(); } @@ -37,9 +36,10 @@ class CMVN : public FrontendInterface { virtual void Reset() { base_extractor_->Reset(); } private: - void Compute(kaldi::VectorBase* feats) const; - void ApplyCMVN(kaldi::MatrixBase* feats); - kaldi::Matrix stats_; + void ReadCMVNFromJson(std::string cmvn_file); + void Compute(std::vector* feats) const; + std::vector mean_stats_; + std::vector var_stats_; std::unique_ptr base_extractor_; size_t dim_; bool var_norm_; diff --git a/speechx/speechx/frontend/audio/compute_fbank_main.cc b/runtime/engine/common/frontend/compute_fbank_main.cc similarity index 89% rename from speechx/speechx/frontend/audio/compute_fbank_main.cc rename to runtime/engine/common/frontend/compute_fbank_main.cc index e2b54a8a..e022207d 100644 --- a/speechx/speechx/frontend/audio/compute_fbank_main.cc +++ b/runtime/engine/common/frontend/compute_fbank_main.cc @@ -16,13 +16,13 @@ #include "base/flags.h" #include "base/log.h" -#include "frontend/audio/audio_cache.h" -#include "frontend/audio/data_cache.h" -#include "frontend/audio/fbank.h" -#include "frontend/audio/feature_cache.h" -#include "frontend/audio/frontend_itf.h" -#include "frontend/audio/normalizer.h" -#include "kaldi/feat/wave-reader.h" +#include "frontend/audio_cache.h" +#include "frontend/data_cache.h" +#include "frontend/fbank.h" +#include "frontend/feature_cache.h" +#include "frontend/frontend_itf.h" +#include "frontend/normalizer.h" +#include "frontend/wave-reader.h" #include "kaldi/util/kaldi-io.h" #include "kaldi/util/table-types.h" @@ -56,7 +56,7 @@ int main(int argc, char* argv[]) { std::unique_ptr data_source( new ppspeech::AudioCache(3600 * 1600, false)); - kaldi::FbankOptions opt; + knf::FbankOptions opt; opt.frame_opts.frame_length_ms = 25; opt.frame_opts.frame_shift_ms = 10; opt.mel_opts.num_bins = FLAGS_num_bins; @@ -73,8 +73,7 @@ int main(int argc, char* argv[]) { new ppspeech::CMVN(FLAGS_cmvn_file, std::move(fbank))); // the feature cache output feature chunk by chunk. 
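All of the stages above implement FrontendInterface, so they compose by wrapping one another; a sketch of the shared streaming contract (simplified from the tool below, error handling omitted):

```cpp
#include "frontend/frontend_itf.h"

void FeedAndDrain(ppspeech::FrontendInterface* frontend,
                  const std::vector<kaldi::BaseFloat>& wav_chunk,
                  bool is_last_chunk) {
    frontend->Accept(wav_chunk);  // push raw samples (or feats) downstream
    if (is_last_chunk) frontend->SetFinished();
    std::vector<kaldi::BaseFloat> feats;
    while (frontend->Read(&feats)) {
        // feats holds Dim() * num_frames values, laid out frame by frame
    }
}
```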
- ppspeech::FeatureCacheOptions feat_cache_opts; - ppspeech::FeatureCache feature_cache(feat_cache_opts, std::move(cmvn)); + ppspeech::FeatureCache feature_cache(kint16max, std::move(cmvn)); LOG(INFO) << "fbank: " << true; LOG(INFO) << "feat dim: " << feature_cache.Dim(); @@ -117,9 +116,9 @@ int main(int argc, char* argv[]) { std::min(chunk_sample_size, tot_samples - sample_offset); // get chunk wav - kaldi::Vector wav_chunk(cur_chunk_size); + std::vector wav_chunk(cur_chunk_size); for (int i = 0; i < cur_chunk_size; ++i) { - wav_chunk(i) = waveform(sample_offset + i); + wav_chunk[i] = waveform(sample_offset + i); } // compute feat @@ -131,10 +130,14 @@ int main(int argc, char* argv[]) { } // read feat - kaldi::Vector features; + kaldi::Vector features(feature_cache.Dim()); bool flag = true; do { - flag = feature_cache.Read(&features); + std::vector tmp; + flag = feature_cache.Read(&tmp); + std::memcpy(features.Data(), + tmp.data(), + tmp.size() * sizeof(BaseFloat)); if (flag && features.Dim() != 0) { feats.push_back(features); feature_rows += features.Dim() / feature_cache.Dim(); diff --git a/speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc b/runtime/engine/common/frontend/compute_linear_spectrogram_main.cc similarity index 100% rename from speechx/speechx/frontend/audio/compute_linear_spectrogram_main.cc rename to runtime/engine/common/frontend/compute_linear_spectrogram_main.cc diff --git a/speechx/speechx/frontend/audio/data_cache.h b/runtime/engine/common/frontend/data_cache.h similarity index 79% rename from speechx/speechx/frontend/audio/data_cache.h rename to runtime/engine/common/frontend/data_cache.h index 5fe5e4fe..7a37adf4 100644 --- a/speechx/speechx/frontend/audio/data_cache.h +++ b/runtime/engine/common/frontend/data_cache.h @@ -15,10 +15,10 @@ #pragma once - #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" +using std::vector; namespace ppspeech { @@ -30,16 +30,16 @@ class DataCache : public FrontendInterface { DataCache() : finished_{false}, dim_{0} {} // accept waves/feats - void Accept(const kaldi::VectorBase& inputs) override { - data_ = inputs; + void Accept(const std::vector& inputs) override { + data_ = std::move(inputs); } - bool Read(kaldi::Vector* feats) override { - if (data_.Dim() == 0) { + bool Read(vector* feats) override { + if (data_.size() == 0) { return false; } - (*feats) = data_; - data_.Resize(0); + (*feats) = std::move(data_); + data_.resize(0); return true; } @@ -53,7 +53,7 @@ class DataCache : public FrontendInterface { } private: - kaldi::Vector data_; + std::vector data_; bool finished_; int32 dim_; diff --git a/speechx/speechx/frontend/audio/db_norm.cc b/runtime/engine/common/frontend/db_norm.cc similarity index 97% rename from speechx/speechx/frontend/audio/db_norm.cc rename to runtime/engine/common/frontend/db_norm.cc index ad79fcc3..7141fc80 100644 --- a/speechx/speechx/frontend/audio/db_norm.cc +++ b/runtime/engine/common/frontend/db_norm.cc @@ -76,7 +76,7 @@ bool DecibelNormalizer::Compute(VectorBase* waves) const { if (gain > opts_.max_gain_db) { LOG(ERROR) << "Unable to normalize segment to " << opts_.target_db << "dB," - << "because the the probable gain have exceeds opts_.max_gain_db" + << "because the probable gain has exceeded opts_.max_gain_db" << opts_.max_gain_db << "dB."; return false; } diff --git a/speechx/speechx/frontend/audio/db_norm.h b/runtime/engine/common/frontend/db_norm.h similarity index 100% rename from speechx/speechx/frontend/audio/db_norm.h rename 
to runtime/engine/common/frontend/db_norm.h diff --git a/runtime/engine/common/frontend/fbank.h b/runtime/engine/common/frontend/fbank.h new file mode 100644 index 00000000..4398e72f --- /dev/null +++ b/runtime/engine/common/frontend/fbank.h @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "base/common.h" +#include "frontend/feature-fbank.h" +#include "frontend/feature_common.h" + +namespace ppspeech { + +typedef StreamingFeatureTpl Fbank; + +} // namespace ppspeech diff --git a/runtime/engine/common/frontend/feature-fbank.cc b/runtime/engine/common/frontend/feature-fbank.cc new file mode 100644 index 00000000..2393e153 --- /dev/null +++ b/runtime/engine/common/frontend/feature-fbank.cc @@ -0,0 +1,123 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/feature-fbank.cc +// +#include "frontend/feature-fbank.h" + +#include + +#include "frontend/feature-functions.h" + +namespace knf { + +static void Sqrt(float *in_out, int32_t n) { + for (int32_t i = 0; i != n; ++i) { + in_out[i] = std::sqrt(in_out[i]); + } +} + +std::ostream &operator<<(std::ostream &os, const FbankOptions &opts) { + os << opts.ToString(); + return os; +} + +FbankComputer::FbankComputer(const FbankOptions &opts) + : opts_(opts), rfft_(opts.frame_opts.PaddedWindowSize()) { + if (opts.energy_floor > 0.0f) { + log_energy_floor_ = logf(opts.energy_floor); + } + + // We'll definitely need the filterbanks info for VTLN warping factor 1.0. + // [note: this call caches it.] 
+ GetMelBanks(1.0f); +} + +FbankComputer::~FbankComputer() { + for (auto iter = mel_banks_.begin(); iter != mel_banks_.end(); ++iter) + delete iter->second; +} + +const MelBanks *FbankComputer::GetMelBanks(float vtln_warp) { + MelBanks *this_mel_banks = nullptr; + + // std::map::iterator iter = mel_banks_.find(vtln_warp); + auto iter = mel_banks_.find(vtln_warp); + if (iter == mel_banks_.end()) { + this_mel_banks = + new MelBanks(opts_.mel_opts, opts_.frame_opts, vtln_warp); + mel_banks_[vtln_warp] = this_mel_banks; + } else { + this_mel_banks = iter->second; + } + return this_mel_banks; +} + +void FbankComputer::Compute(float signal_raw_log_energy, + float vtln_warp, + std::vector *signal_frame, + float *feature) { + const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); + + CHECK_EQ(signal_frame->size(), opts_.frame_opts.PaddedWindowSize()); + + // Compute energy after window function (not the raw one). + if (opts_.use_energy && !opts_.raw_energy) { + signal_raw_log_energy = + std::log(std::max(InnerProduct(signal_frame->data(), + signal_frame->data(), + signal_frame->size()), + std::numeric_limits::epsilon())); + } + rfft_.Compute(signal_frame->data()); // signal_frame is modified in-place + ComputePowerSpectrum(signal_frame); + + // Use magnitude instead of power if requested. + if (!opts_.use_power) { + Sqrt(signal_frame->data(), signal_frame->size() / 2 + 1); + } + + int32_t mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0); + + // Its length is opts_.mel_opts.num_bins + float *mel_energies = feature + mel_offset; + + // Sum with mel filter banks over the power spectrum + mel_banks.Compute(signal_frame->data(), mel_energies); + + if (opts_.use_log_fbank) { + // Avoid log of zero (which should be prevented anyway by dithering). + for (int32_t i = 0; i != opts_.mel_opts.num_bins; ++i) { + auto t = std::max(mel_energies[i], + std::numeric_limits::epsilon()); + mel_energies[i] = std::log(t); + } + } + + // Copy energy as first value (or the last, if htk_compat == true). + if (opts_.use_energy) { + if (opts_.energy_floor > 0.0 && + signal_raw_log_energy < log_energy_floor_) { + signal_raw_log_energy = log_energy_floor_; + } + int32_t energy_index = opts_.htk_compat ? opts_.mel_opts.num_bins : 0; + feature[energy_index] = signal_raw_log_energy; + } +} + +} // namespace knf diff --git a/runtime/engine/common/frontend/feature-fbank.h b/runtime/engine/common/frontend/feature-fbank.h new file mode 100644 index 00000000..3dab793f --- /dev/null +++ b/runtime/engine/common/frontend/feature-fbank.h @@ -0,0 +1,138 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// This file is copied/modified from kaldi/src/feat/feature-fbank.h + +#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ +#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ + +#include +#include + +#include "frontend/feature-window.h" +#include "frontend/mel-computations.h" +#include "frontend/rfft.h" + +namespace knf { + +struct FbankOptions { + FrameExtractionOptions frame_opts; + MelBanksOptions mel_opts; + // append an extra dimension with energy to the filter banks + bool use_energy = false; + float energy_floor = 0.0f; // active iff use_energy==true + + // If true, compute log_energy before preemphasis and windowing + // If false, compute log_energy after preemphasis ans windowing + bool raw_energy = true; // active iff use_energy==true + + // If true, put energy last (if using energy) + // If false, put energy first + bool htk_compat = false; // active iff use_energy==true + + // if true (default), produce log-filterbank, else linear + bool use_log_fbank = true; + + // if true (default), use power in filterbank + // analysis, else magnitude. + bool use_power = true; + + FbankOptions() { mel_opts.num_bins = 23; } + + std::string ToString() const { + std::ostringstream os; + os << "frame_opts: \n"; + os << frame_opts << "\n"; + os << "\n"; + + os << "mel_opts: \n"; + os << mel_opts << "\n"; + + os << "use_energy: " << use_energy << "\n"; + os << "energy_floor: " << energy_floor << "\n"; + os << "raw_energy: " << raw_energy << "\n"; + os << "htk_compat: " << htk_compat << "\n"; + os << "use_log_fbank: " << use_log_fbank << "\n"; + os << "use_power: " << use_power << "\n"; + return os.str(); + } +}; + +std::ostream &operator<<(std::ostream &os, const FbankOptions &opts); + +class FbankComputer { + public: + using Options = FbankOptions; + + explicit FbankComputer(const FbankOptions &opts); + ~FbankComputer(); + + int32_t Dim() const { + return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0); + } + + // if true, compute log_energy_pre_window but after dithering and dc removal + bool NeedRawLogEnergy() const { + return opts_.use_energy && opts_.raw_energy; + } + + const FrameExtractionOptions &GetFrameOptions() const { + return opts_.frame_opts; + } + + const FbankOptions &GetOptions() const { return opts_; } + + /** + Function that computes one frame of features from + one frame of signal. + + @param [in] signal_raw_log_energy The log-energy of the frame of the + signal + prior to windowing and pre-emphasis, or + log(numeric_limits::min()), whichever is greater. Must be + ignored by this function if this class returns false from + this->NeedsRawLogEnergy(). + @param [in] vtln_warp The VTLN warping factor that the user wants + to be applied when computing features for this utterance. Will + normally be 1.0, meaning no warping is to be done. The value will + be ignored for feature types that don't support VLTN, such as + spectrogram features. + @param [in] signal_frame One frame of the signal, + as extracted using the function ExtractWindow() using the options + returned by this->GetFrameOptions(). The function will use the + vector as a workspace, which is why it's a non-const pointer. + @param [out] feature Pointer to a vector of size this->Dim(), to which + the computed feature will be written. It should be pre-allocated. 
+ */ + void Compute(float signal_raw_log_energy, + float vtln_warp, + std::vector *signal_frame, + float *feature); + + private: + const MelBanks *GetMelBanks(float vtln_warp); + + FbankOptions opts_; + float log_energy_floor_; + std::map mel_banks_; // float is VTLN coefficient. + Rfft rfft_; +}; + +} // namespace knf + +#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ diff --git a/runtime/engine/common/frontend/feature-functions.cc b/runtime/engine/common/frontend/feature-functions.cc new file mode 100644 index 00000000..178c711b --- /dev/null +++ b/runtime/engine/common/frontend/feature-functions.cc @@ -0,0 +1,49 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/feature-functions.cc + +#include "frontend/feature-functions.h" + +#include +#include + +namespace knf { + +void ComputePowerSpectrum(std::vector *complex_fft) { + int32_t dim = complex_fft->size(); + + // now we have in complex_fft, first half of complex spectrum + // it's stored as [real0, realN/2, real1, im1, real2, im2, ...] + + float *p = complex_fft->data(); + int32_t half_dim = dim / 2; + float first_energy = p[0] * p[0]; + float last_energy = p[1] * p[1]; // handle this special case + + for (int32_t i = 1; i < half_dim; ++i) { + float real = p[i * 2]; + float im = p[i * 2 + 1]; + p[i] = real * real + im * im; + } + p[0] = first_energy; + p[half_dim] = last_energy; // Will actually never be used, and anyway + // if the signal has been bandlimited sensibly this should be zero. +} + +} // namespace knf diff --git a/runtime/engine/common/frontend/feature-functions.h b/runtime/engine/common/frontend/feature-functions.h new file mode 100644 index 00000000..852d0612 --- /dev/null +++ b/runtime/engine/common/frontend/feature-functions.h @@ -0,0 +1,38 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/feature-functions.h +#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H +#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H + +#include +namespace knf { + +// ComputePowerSpectrum converts a complex FFT (as produced by the FFT +// functions in csrc/rfft.h), and converts it into +// a power spectrum. 
+// If the complex FFT is a vector of size n (representing
+// half of the complex FFT of a real signal of size n, as described there),
+// this function computes in the first (n/2) + 1 elements of it, the
+// energies of the fft bins from zero to the Nyquist frequency. Contents of
+// the remaining (n/2) - 1 elements are undefined at output.
+
+void ComputePowerSpectrum(std::vector<float> *complex_fft);
+
+}  // namespace knf
+
+#endif  // KALDI_NATIVE_FBANK_CSRC_FEATURE_FUNCTIONS_H
diff --git a/runtime/engine/common/frontend/feature-window.cc b/runtime/engine/common/frontend/feature-window.cc
new file mode 100644
index 00000000..43c736e0
--- /dev/null
+++ b/runtime/engine/common/frontend/feature-window.cc
@@ -0,0 +1,248 @@
+// kaldi-native-fbank/csrc/feature-window.cc
+//
+// Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+
+// This file is copied/modified from kaldi/src/feat/feature-window.cc
+
+#include "frontend/feature-window.h"
+
+#include <cmath>
+#include <cstring>
+#include <vector>
+
+#ifndef M_2PI
+#define M_2PI 6.283185307179586476925286766559005
+#endif
+
+namespace knf {
+
+std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts) {
+    os << opts.ToString();
+    return os;
+}
+
+FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts)
+    : window_(opts.WindowSize()) {
+    int32_t frame_length = opts.WindowSize();
+    CHECK_GT(frame_length, 0);
+
+    float *window_data = window_.data();
+
+    double a = M_2PI / (frame_length - 1);
+    for (int32_t i = 0; i < frame_length; i++) {
+        double i_fl = static_cast<double>(i);
+        if (opts.window_type == "hanning") {
+            window_data[i] = 0.5 - 0.5 * cos(a * i_fl);
+        } else if (opts.window_type == "sine") {
+            // when you are checking the wikipedia page, please
+            // note that 0.5 * a = M_PI/(frame_length-1)
+            window_data[i] = sin(0.5 * a * i_fl);
+        } else if (opts.window_type == "hamming") {
+            window_data[i] = 0.54 - 0.46 * cos(a * i_fl);
+        } else if (opts.window_type ==
+                   "povey") {  // like hamming but goes to zero at edges.
+            window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85);
+        } else if (opts.window_type == "rectangular") {
+            window_data[i] = 1.0;
+        } else if (opts.window_type == "blackman") {
+            window_data[i] = opts.blackman_coeff - 0.5 * cos(a * i_fl) +
+                             (0.5 - opts.blackman_coeff) * cos(2 * a * i_fl);
+        } else {
+            LOG(FATAL) << "Invalid window type " << opts.window_type;
+        }
+    }
+}
+
+void FeatureWindowFunction::Apply(float *wave) const {
+    int32_t window_size = window_.size();
+    const float *p = window_.data();
+    for (int32_t k = 0; k != window_size; ++k) {
+        wave[k] *= p[k];
+    }
+}
+
+int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts) {
+    int64_t frame_shift = opts.WindowShift();
+    if (opts.snip_edges) {
+        return frame * frame_shift;
+    } else {
+        int64_t midpoint_of_frame = frame_shift * frame + frame_shift / 2,
+                beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2;
+        return beginning_of_frame;
+    }
+}
+
+int32_t NumFrames(int64_t num_samples,
+                  const FrameExtractionOptions &opts,
+                  bool flush /*= true*/) {
+    int64_t frame_shift = opts.WindowShift();
+    int64_t frame_length = opts.WindowSize();
+    if (opts.snip_edges) {
+        // with --snip-edges=true (the default), we use a HTK-like approach to
+        // determining the number of frames -- all frames have to fit
+        // completely into the waveform, and the first frame begins at sample
+        // zero.
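As a quick check on the window formulas above, a standalone sketch (it recomputes the same expressions; nothing here is part of the knf API) showing why "povey" goes to zero at the edges while Hamming stays at 0.08:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int N = 400;  // e.g. 25 ms at 16 kHz
        const double PI = 3.14159265358979323846;
        const double a = 2.0 * PI / (N - 1);
        // At i = 0 (and, by symmetry, i = N-1):
        double hamming_edge = 0.54 - 0.46 * std::cos(0.0);            // 0.08
        double povey_edge = std::pow(0.5 - 0.5 * std::cos(0.0), 0.85);  // 0.0
        // At the midpoint i = (N-1)/2 the argument is pi, so both peak at 1.
        double hamming_mid = 0.54 - 0.46 * std::cos(a * (N - 1) / 2.0);
        std::printf("hamming edge=%.2f povey edge=%.2f hamming mid=%.2f\n",
                    hamming_edge, povey_edge, hamming_mid);
        return 0;
    }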
+        if (num_samples < frame_length)
+            return 0;
+        else
+            return (1 + ((num_samples - frame_length) / frame_shift));
+        // You can understand the expression above as follows: 'num_samples -
+        // frame_length' is how much room we have to shift the frame within
+        // the waveform; 'frame_shift' is how much we shift it each time; and
+        // the ratio is how many times we can shift it (integer arithmetic
+        // rounds down).
+    } else {
+        // if --snip-edges=false, the number of frames is determined by
+        // rounding the (file-length / frame-shift) to the nearest integer.
+        // The point of this formula is to make the number of frames an
+        // obvious and predictable function of the frame shift and signal
+        // length, which makes many segmentation-related questions simpler.
+        //
+        // Because integer division in C++ rounds toward zero, we add (half
+        // the frame-shift minus epsilon) before dividing, to have the effect
+        // of rounding towards the closest integer.
+        int32_t num_frames = (num_samples + (frame_shift / 2)) / frame_shift;
+
+        if (flush) return num_frames;
+
+        // note: 'end' always means the last plus one, i.e. one past the last.
+        int64_t end_sample_of_last_frame =
+            FirstSampleOfFrame(num_frames - 1, opts) + frame_length;
+
+        // the following code is optimized more for clarity than efficiency.
+        // If flush == false, we can't output frames that extend past the end
+        // of the signal.
+        while (num_frames > 0 && end_sample_of_last_frame > num_samples) {
+            num_frames--;
+            end_sample_of_last_frame -= frame_shift;
+        }
+        return num_frames;
+    }
+}
+
+void ExtractWindow(int64_t sample_offset,
+                   const std::vector<float> &wave,
+                   int32_t f,
+                   const FrameExtractionOptions &opts,
+                   const FeatureWindowFunction &window_function,
+                   std::vector<float> *window,
+                   float *log_energy_pre_window /*= nullptr*/) {
+    CHECK(sample_offset >= 0 && wave.size() != 0);
+
+    int32_t frame_length = opts.WindowSize();
+    int32_t frame_length_padded = opts.PaddedWindowSize();
+
+    int64_t num_samples = sample_offset + wave.size();
+    int64_t start_sample = FirstSampleOfFrame(f, opts);
+    int64_t end_sample = start_sample + frame_length;
+
+    if (opts.snip_edges) {
+        CHECK(start_sample >= sample_offset && end_sample <= num_samples);
+    } else {
+        CHECK(sample_offset == 0 || start_sample >= sample_offset);
+    }
+
+    if (window->size() != frame_length_padded) {
+        window->resize(frame_length_padded);
+    }
+
+    // wave_start and wave_end are start and end indexes into 'wave', for the
+    // piece of wave that we're trying to extract.
+    int32_t wave_start = int32_t(start_sample - sample_offset);
+    int32_t wave_end = wave_start + frame_length;
+
+    if (wave_start >= 0 && wave_end <= wave.size()) {
+        // the normal case -- no edge effects to consider.
+        std::copy(wave.begin() + wave_start,
+                  wave.begin() + wave_start + frame_length,
+                  window->data());
+    } else {
+        // Deal with any end effects by reflection, if needed. This code will
+        // only be reached for about two frames per utterance, so we don't
+        // concern ourselves excessively with efficiency.
+        int32_t wave_dim = wave.size();
+        for (int32_t s = 0; s < frame_length; ++s) {
+            int32_t s_in_wave = s + wave_start;
+            while (s_in_wave < 0 || s_in_wave >= wave_dim) {
+                // reflect around the beginning or end of the wave.
+                // e.g. -1 -> 0, -2 -> 1.
+                // dim -> dim - 1, dim + 1 -> dim - 2.
+                // the code supports repeated reflections, although this
+                // would only be needed in pathological cases.
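A worked example of the two frame-counting rules in NumFrames(), using the 16 kHz defaults (400-sample window, 160-sample shift) and a hypothetical 1000-sample waveform; this just redoes the arithmetic above:

    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t num_samples = 1000;  // hypothetical waveform length
        const int64_t frame_length = 400;  // 25 ms at 16 kHz
        const int64_t frame_shift = 160;   // 10 ms at 16 kHz

        // snip_edges == true: frames must fit entirely inside the waveform.
        int64_t snip = 1 + (num_samples - frame_length) / frame_shift;
        assert(snip == 4);  // frames start at samples 0, 160, 320, 480

        // snip_edges == false (with flush == true): round the count of
        // shifts to the nearest integer.
        int64_t no_snip = (num_samples + frame_shift / 2) / frame_shift;
        assert(no_snip == 6);  // (1000 + 80) / 160 rounds down to 6
        return 0;
    }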
+                if (s_in_wave < 0)
+                    s_in_wave = -s_in_wave - 1;
+                else
+                    s_in_wave = 2 * wave_dim - 1 - s_in_wave;
+            }
+            (*window)[s] = wave[s_in_wave];
+        }
+    }
+
+    ProcessWindow(opts, window_function, window->data(), log_energy_pre_window);
+}
+
+static void RemoveDcOffset(float *d, int32_t n) {
+    float sum = 0;
+    for (int32_t i = 0; i != n; ++i) {
+        sum += d[i];
+    }
+
+    float mean = sum / n;
+
+    for (int32_t i = 0; i != n; ++i) {
+        d[i] -= mean;
+    }
+}
+
+float InnerProduct(const float *a, const float *b, int32_t n) {
+    float sum = 0;
+    for (int32_t i = 0; i != n; ++i) {
+        sum += a[i] * b[i];
+    }
+    return sum;
+}
+
+static void Preemphasize(float *d, int32_t n, float preemph_coeff) {
+    if (preemph_coeff == 0.0) {
+        return;
+    }
+
+    CHECK(preemph_coeff >= 0.0 && preemph_coeff <= 1.0);
+
+    for (int32_t i = n - 1; i > 0; --i) {
+        d[i] -= preemph_coeff * d[i - 1];
+    }
+    d[0] -= preemph_coeff * d[0];
+}
+
+void ProcessWindow(const FrameExtractionOptions &opts,
+                   const FeatureWindowFunction &window_function,
+                   float *window,
+                   float *log_energy_pre_window /*= nullptr*/) {
+    int32_t frame_length = opts.WindowSize();
+
+    // TODO(fangjun): Remove dither
+    CHECK_EQ(opts.dither, 0);
+
+    if (opts.remove_dc_offset) {
+        RemoveDcOffset(window, frame_length);
+    }
+
+    if (log_energy_pre_window != NULL) {
+        float energy =
+            std::max(InnerProduct(window, window, frame_length),
+                     std::numeric_limits<float>::epsilon());
+        *log_energy_pre_window = std::log(energy);
+    }
+
+    if (opts.preemph_coeff != 0.0) {
+        Preemphasize(window, frame_length, opts.preemph_coeff);
+    }
+
+    window_function.Apply(window);
+}
+
+}  // namespace knf
diff --git a/runtime/engine/common/frontend/feature-window.h b/runtime/engine/common/frontend/feature-window.h
new file mode 100644
index 00000000..8c86bf05
--- /dev/null
+++ b/runtime/engine/common/frontend/feature-window.h
@@ -0,0 +1,183 @@
+// kaldi-native-fbank/csrc/feature-window.h
+//
+// Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+
+// This file is copied/modified from kaldi/src/feat/feature-window.h
+
+#ifndef KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_
+#define KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "base/log.h"
+
+namespace knf {
+
+inline int32_t RoundUpToNearestPowerOfTwo(int32_t n) {
+    // copied from kaldi/src/base/kaldi-math.cc
+    CHECK_GT(n, 0);
+    n--;
+    n |= n >> 1;
+    n |= n >> 2;
+    n |= n >> 4;
+    n |= n >> 8;
+    n |= n >> 16;
+    return n + 1;
+}
+
+struct FrameExtractionOptions {
+    float samp_freq = 16000;
+    float frame_shift_ms = 10.0f;   // in milliseconds.
+    float frame_length_ms = 25.0f;  // in milliseconds.
+    float dither = 1.0f;  // Amount of dithering, 0.0 means no dither.
+    float preemph_coeff = 0.97f;   // Preemphasis coefficient.
+    bool remove_dc_offset = true;  // Subtract mean of wave before FFT.
+    std::string window_type = "povey";  // e.g. Hamming window
+    // May be "hamming", "rectangular", "povey", "hanning", "sine", "blackman"
+    // "povey" is a window I made to be similar to Hamming but to go to zero
+    // at the edges, it's pow((0.5 - 0.5*cos(n/N*2*pi)), 0.85). I just don't
+    // think the Hamming window makes sense as a windowing function.
+    bool round_to_power_of_two = true;
+    float blackman_coeff = 0.42f;
+    bool snip_edges = true;
+    // bool allow_downsample = false;
+    // bool allow_upsample = false;
+
+    // Used for streaming feature extraction. It indicates the number
+    // of feature frames to keep in the recycling vector. -1 means to
+    // keep all feature frames.
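The reflection rule in ExtractWindow() can be sanity-checked in isolation; Reflect() below is an illustrative helper, not part of the source, that repeats the same while-loop:

    #include <cassert>
    #include <cstdint>

    // Mirror an out-of-range sample index back into [0, dim), exactly as
    // the while-loop inside ExtractWindow() does.
    int32_t Reflect(int32_t s_in_wave, int32_t dim) {
        while (s_in_wave < 0 || s_in_wave >= dim) {
            if (s_in_wave < 0)
                s_in_wave = -s_in_wave - 1;  // -1 -> 0, -2 -> 1
            else
                s_in_wave = 2 * dim - 1 - s_in_wave;  // dim -> dim - 1
        }
        return s_in_wave;
    }

    int main() {
        assert(Reflect(-1, 100) == 0);
        assert(Reflect(-2, 100) == 1);
        assert(Reflect(100, 100) == 99);
        assert(Reflect(101, 100) == 98);
        return 0;
    }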
+    int32_t max_feature_vectors = -1;
+
+    int32_t WindowShift() const {
+        return static_cast<int32_t>(samp_freq * 0.001f * frame_shift_ms);
+    }
+    int32_t WindowSize() const {
+        return static_cast<int32_t>(samp_freq * 0.001f * frame_length_ms);
+    }
+    int32_t PaddedWindowSize() const {
+        return (round_to_power_of_two ? RoundUpToNearestPowerOfTwo(WindowSize())
+                                      : WindowSize());
+    }
+    std::string ToString() const {
+        std::ostringstream os;
+#define KNF_PRINT(x) os << #x << ": " << x << "\n"
+        KNF_PRINT(samp_freq);
+        KNF_PRINT(frame_shift_ms);
+        KNF_PRINT(frame_length_ms);
+        KNF_PRINT(dither);
+        KNF_PRINT(preemph_coeff);
+        KNF_PRINT(remove_dc_offset);
+        KNF_PRINT(window_type);
+        KNF_PRINT(round_to_power_of_two);
+        KNF_PRINT(blackman_coeff);
+        KNF_PRINT(snip_edges);
+        // KNF_PRINT(allow_downsample);
+        // KNF_PRINT(allow_upsample);
+        KNF_PRINT(max_feature_vectors);
+#undef KNF_PRINT
+        return os.str();
+    }
+};
+
+std::ostream &operator<<(std::ostream &os, const FrameExtractionOptions &opts);
+
+class FeatureWindowFunction {
+  public:
+    FeatureWindowFunction() = default;
+    explicit FeatureWindowFunction(const FrameExtractionOptions &opts);
+    /**
+     * @param wave Pointer to a 1-D array of shape [window_size].
+     *             It is modified in-place: wave[i] = wave[i] * window_[i].
+     */
+    void Apply(float *wave) const;
+
+  private:
+    std::vector<float> window_;  // of size opts.WindowSize()
+};
+
+int64_t FirstSampleOfFrame(int32_t frame, const FrameExtractionOptions &opts);
+
+/**
+   This function returns the number of frames that we can extract from a wave
+   file with the given number of samples in it (assumed to have the same
+   sampling rate as specified in 'opts').
+
+      @param [in] num_samples The number of samples in the wave file.
+      @param [in] opts     The frame-extraction options class
+
+      @param [in] flush   True if we are asserting that this number of samples
+   is 'all there is', false if we are expecting more data to possibly come in.
+   This only makes a difference to the answer if opts.snip_edges == false.
+   For offline feature extraction you always want flush == true. In an
+   online-decoding context, once you know (or decide) that no more data is
+   coming in, you'd call it with flush == true at the end to flush out any
+   remaining data.
+*/
+int32_t NumFrames(int64_t num_samples,
+                  const FrameExtractionOptions &opts,
+                  bool flush = true);
+
+/*
+  ExtractWindow() extracts a windowed frame of waveform (possibly with a
+  power-of-two, padded size, depending on the config), including all the
+  processing done by ProcessWindow().
+
+  @param [in] sample_offset  If 'wave' is not the entire waveform, but
+                   part of it to the left has been discarded, then the
+                   number of samples prior to 'wave' that we have
+                   already discarded. Set this to zero if you are
+                   processing the entire waveform in one piece, or
+                   if you get 'no matching function' compilation
+                   errors when updating the code.
+  @param [in] wave  The waveform
+  @param [in] f     The frame index to be extracted, with
+                    0 <= f < NumFrames(sample_offset + wave.Dim(), opts, true)
+  @param [in] opts  The options class to be used
+  @param [in] window_function  The windowing function, as derived from the
+                    options class.
+  @param [out] window  The windowed, possibly-padded waveform to be
+                    extracted. Will be resized as needed.
+  @param [out] log_energy_pre_window  If non-NULL, the log-energy of
+                    the signal prior to pre-emphasis and multiplying by
+                    the windowing function will be written to here.
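With the 16 kHz defaults, WindowSize() is 400 samples, so PaddedWindowSize() rounds up to 512. A minimal check of the bit-trick, duplicated here purely for illustration:

    #include <cassert>
    #include <cstdint>

    int32_t RoundUpToNearestPowerOfTwo(int32_t n) {
        n--;           // so exact powers of two map to themselves
        n |= n >> 1;
        n |= n >> 2;
        n |= n >> 4;
        n |= n >> 8;
        n |= n >> 16;  // every bit below the top set bit is now 1
        return n + 1;
    }

    int main() {
        assert(RoundUpToNearestPowerOfTwo(400) == 512);  // 25 ms @ 16 kHz
        assert(RoundUpToNearestPowerOfTwo(512) == 512);  // power of two kept
        assert(RoundUpToNearestPowerOfTwo(200) == 256);  // 25 ms @ 8 kHz
        return 0;
    }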
+*/
+void ExtractWindow(int64_t sample_offset,
+                   const std::vector<float> &wave,
+                   int32_t f,
+                   const FrameExtractionOptions &opts,
+                   const FeatureWindowFunction &window_function,
+                   std::vector<float> *window,
+                   float *log_energy_pre_window = nullptr);
+
+/**
+  This function does all the windowing steps after actually
+  extracting the windowed signal: depending on the
+  configuration, it does dithering, dc offset removal,
+  preemphasis, and multiplication by the windowing function.
+   @param [in] opts  The options class to be used
+   @param [in] window_function  The windowing function -- should have
+                    been initialized using 'opts'.
+   @param [in,out] window  A vector of size opts.WindowSize(). Note:
+       it will typically be a sub-vector of a larger vector of size
+       opts.PaddedWindowSize(), with the remaining samples zero,
+       as the FFT code is more efficient if it operates on data with
+       power-of-two size.
+   @param [out] log_energy_pre_window If non-NULL, then after dithering and
+       DC offset removal, this function will write to this pointer the log of
+       the total energy (i.e. sum-squared) of the frame.
+ */
+void ProcessWindow(const FrameExtractionOptions &opts,
+                   const FeatureWindowFunction &window_function,
+                   float *window,
+                   float *log_energy_pre_window = nullptr);
+
+// Compute the inner product of two vectors
+float InnerProduct(const float *a, const float *b, int32_t n);
+
+}  // namespace knf
+
+#endif  // KALDI_NATIVE_FEAT_CSRC_FEATURE_WINDOW_H_
diff --git a/speechx/speechx/frontend/audio/feature_cache.cc b/runtime/engine/common/frontend/feature_cache.cc
similarity index 50%
rename from speechx/speechx/frontend/audio/feature_cache.cc
rename to runtime/engine/common/frontend/feature_cache.cc
index 5110d704..650c84cc 100644
--- a/speechx/speechx/frontend/audio/feature_cache.cc
+++ b/runtime/engine/common/frontend/feature_cache.cc
@@ -12,94 +12,72 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
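Putting the pieces together, a hedged sketch of how a caller is expected to drive these declarations per frame; WindowAllFrames is illustrative, not part of the header, and only the knf functions above are assumed:

    #include <cstdint>
    #include <vector>

    #include "frontend/feature-window.h"

    // Window every frame of a 16 kHz waveform the way the fbank frontend
    // does: NumFrames() -> ExtractWindow() -> (inside it) ProcessWindow().
    std::vector<std::vector<float>> WindowAllFrames(
        const std::vector<float> &wave) {
        knf::FrameExtractionOptions opts;  // 25 ms / 10 ms povey defaults
        opts.dither = 0.0f;                // ProcessWindow CHECKs dither == 0
        knf::FeatureWindowFunction window_function(opts);

        int32_t n = knf::NumFrames(wave.size(), opts, /*flush=*/true);
        std::vector<std::vector<float>> frames(n);
        for (int32_t f = 0; f < n; ++f) {
            float log_energy = 0.0f;
            knf::ExtractWindow(/*sample_offset=*/0, wave, f, opts,
                               window_function, &frames[f], &log_energy);
            // frames[f] now holds a DC-removed, preemphasized, windowed
            // frame, padded to PaddedWindowSize() (512 samples by default).
        }
        return frames;
    }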
-#include "frontend/audio/feature_cache.h" +#include "frontend/feature_cache.h" namespace ppspeech { using kaldi::BaseFloat; -using kaldi::SubVector; -using kaldi::Vector; -using kaldi::VectorBase; using std::unique_ptr; using std::vector; -FeatureCache::FeatureCache(FeatureCacheOptions opts, +FeatureCache::FeatureCache(size_t max_size, unique_ptr base_extractor) { - max_size_ = opts.max_size; - timeout_ = opts.timeout; // ms + max_size_ = max_size; base_extractor_ = std::move(base_extractor); dim_ = base_extractor_->Dim(); } -void FeatureCache::Accept(const kaldi::VectorBase& inputs) { +void FeatureCache::Accept(const std::vector& inputs) { // read inputs base_extractor_->Accept(inputs); - - // feed current data - bool result = false; - do { - result = Compute(); - } while (result); } // pop feature chunk -bool FeatureCache::Read(kaldi::Vector* feats) { +bool FeatureCache::Read(std::vector* feats) { kaldi::Timer timer; - std::unique_lock lock(mutex_); - while (cache_.empty() && base_extractor_->IsFinished() == false) { - // todo refactor: wait - // ready_read_condition_.wait(lock); - int32 elapsed = static_cast(timer.Elapsed() * 1000); // ms - if (elapsed > timeout_) { - return false; - } - usleep(100); // sleep 0.1 ms + // feed current data + if (cache_.empty()) { + bool result = false; + do { + result = Compute(); + } while (result); } + if (cache_.empty()) return false; // read from cache - feats->Resize(cache_.front().Dim()); - feats->CopyFromVec(cache_.front()); + *feats = cache_.front(); cache_.pop(); - ready_feed_condition_.notify_one(); - VLOG(1) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; + VLOG(2) << "FeatureCache::Read cost: " << timer.Elapsed() << " sec."; + VLOG(1) << "FeatureCache::size : " << cache_.size(); return true; } // read all data from base_feature_extractor_ into cache_ bool FeatureCache::Compute() { // compute and feed - Vector feature; + vector feature; bool result = base_extractor_->Read(&feature); - if (result == false || feature.Dim() == 0) return false; + if (result == false || feature.size() == 0) return false; kaldi::Timer timer; - int32 num_chunk = feature.Dim() / dim_; - nframe_ += num_chunk; + int32 num_chunk = feature.size() / dim_; VLOG(3) << "nframe computed: " << nframe_; for (int chunk_idx = 0; chunk_idx < num_chunk; ++chunk_idx) { int32 start = chunk_idx * dim_; - Vector feature_chunk(dim_); - SubVector tmp(feature.Data() + start, dim_); - feature_chunk.CopyFromVec(tmp); - - std::unique_lock lock(mutex_); - while (cache_.size() >= max_size_) { - // cache full, wait - ready_feed_condition_.wait(lock); - } - + vector feature_chunk(feature.data() + start, + feature.data() + start + dim_); // feed cache cache_.push(feature_chunk); - ready_read_condition_.notify_one(); + ++nframe_; } - VLOG(1) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. " + VLOG(2) << "FeatureCache::Compute cost: " << timer.Elapsed() << " sec. 
" << num_chunk << " feats."; return true; } -} // namespace ppspeech \ No newline at end of file +} // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/feature_cache.h b/runtime/engine/common/frontend/feature_cache.h similarity index 54% rename from speechx/speechx/frontend/audio/feature_cache.h rename to runtime/engine/common/frontend/feature_cache.h index a4ebd604..549a5724 100644 --- a/speechx/speechx/frontend/audio/feature_cache.h +++ b/runtime/engine/common/frontend/feature_cache.h @@ -15,67 +15,51 @@ #pragma once #include "base/common.h" -#include "frontend/audio/frontend_itf.h" +#include "frontend/frontend_itf.h" namespace ppspeech { -struct FeatureCacheOptions { - int32 max_size; - int32 timeout; // ms - FeatureCacheOptions() : max_size(kint16max), timeout(1) {} -}; - class FeatureCache : public FrontendInterface { public: explicit FeatureCache( - FeatureCacheOptions opts, + size_t max_size = kint16max, std::unique_ptr base_extractor = NULL); // Feed feats or waves - virtual void Accept(const kaldi::VectorBase& inputs); + virtual void Accept(const std::vector& inputs); // feats size = num_frames * feat_dim - virtual bool Read(kaldi::Vector* feats); + virtual bool Read(std::vector* feats); // feat dim virtual size_t Dim() const { return dim_; } virtual void SetFinished() { - LOG(INFO) << "set finished"; - // std::unique_lock lock(mutex_); + std::unique_lock lock(mutex_); base_extractor_->SetFinished(); - - // read the last chunk data - Compute(); - // ready_feed_condition_.notify_one(); - LOG(INFO) << "compute last feats done."; } - virtual bool IsFinished() const { return base_extractor_->IsFinished(); } + virtual bool IsFinished() const { + return base_extractor_->IsFinished() && cache_.empty(); + } void Reset() override { - std::queue> empty; + std::queue> empty; + VLOG(1) << "feature cache size: " << cache_.size(); std::swap(cache_, empty); nframe_ = 0; base_extractor_->Reset(); - VLOG(3) << "feature cache reset: cache size: " << cache_.size(); } private: bool Compute(); int32 dim_; - size_t max_size_; // cache capacity - int32 frame_chunk_size_; // window - int32 frame_chunk_stride_; // stride + size_t max_size_; // cache capacity std::unique_ptr base_extractor_; - kaldi::int32 timeout_; // ms - kaldi::Vector remained_feature_; - std::queue> cache_; // feature cache + std::queue> cache_; // feature cache std::mutex mutex_; - std::condition_variable ready_feed_condition_; - std::condition_variable ready_read_condition_; int32 nframe_; // num of feature computed DISALLOW_COPY_AND_ASSIGN(FeatureCache); diff --git a/speechx/speechx/frontend/audio/feature_common.h b/runtime/engine/common/frontend/feature_common.h similarity index 74% rename from speechx/speechx/frontend/audio/feature_common.h rename to runtime/engine/common/frontend/feature_common.h index bad705c9..fcc9100c 100644 --- a/speechx/speechx/frontend/audio/feature_common.h +++ b/runtime/engine/common/frontend/feature_common.h @@ -14,8 +14,8 @@ #pragma once +#include "frontend/feature-window.h" #include "frontend_itf.h" -#include "kaldi/feat/feature-window.h" namespace ppspeech { @@ -25,8 +25,8 @@ class StreamingFeatureTpl : public FrontendInterface { typedef typename F::Options Options; StreamingFeatureTpl(const Options& opts, std::unique_ptr base_extractor); - virtual void Accept(const kaldi::VectorBase& waves); - virtual bool Read(kaldi::Vector* feats); + virtual void Accept(const std::vector& waves); + virtual bool Read(std::vector* feats); // the dim_ is the dim of single frame feature virtual size_t 
Dim() const { return computer_.Dim(); } @@ -37,19 +37,19 @@ class StreamingFeatureTpl : public FrontendInterface { virtual void Reset() { base_extractor_->Reset(); - remained_wav_.Resize(0); + remained_wav_.resize(0); } private: - bool Compute(const kaldi::Vector& waves, - kaldi::Vector* feats); + bool Compute(const std::vector& waves, + std::vector* feats); Options opts_; std::unique_ptr base_extractor_; - kaldi::FeatureWindowFunction window_function_; - kaldi::Vector remained_wav_; + knf::FeatureWindowFunction window_function_; + std::vector remained_wav_; F computer_; }; } // namespace ppspeech -#include "frontend/audio/feature_common_inl.h" +#include "frontend/feature_common_inl.h" diff --git a/runtime/engine/common/frontend/feature_common_inl.h b/runtime/engine/common/frontend/feature_common_inl.h new file mode 100644 index 00000000..ac239974 --- /dev/null +++ b/runtime/engine/common/frontend/feature_common_inl.h @@ -0,0 +1,102 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +namespace ppspeech { + +template +StreamingFeatureTpl::StreamingFeatureTpl( + const Options& opts, std::unique_ptr base_extractor) + : opts_(opts), computer_(opts), window_function_(opts.frame_opts) { + base_extractor_ = std::move(base_extractor); +} + +template +void StreamingFeatureTpl::Accept( + const std::vector& waves) { + base_extractor_->Accept(waves); +} + +template +bool StreamingFeatureTpl::Read(std::vector* feats) { + std::vector wav(base_extractor_->Dim()); + bool flag = base_extractor_->Read(&wav); + if (flag == false || wav.size() == 0) return false; + + // append remaned waves + int32 wav_len = wav.size(); + int32 left_len = remained_wav_.size(); + std::vector waves(left_len + wav_len); + std::memcpy(waves.data(), + remained_wav_.data(), + left_len * sizeof(kaldi::BaseFloat)); + std::memcpy(waves.data() + left_len, + wav.data(), + wav_len * sizeof(kaldi::BaseFloat)); + + // compute speech feature + Compute(waves, feats); + + // cache remaned waves + knf::FrameExtractionOptions frame_opts = computer_.GetFrameOptions(); + int32 num_frames = knf::NumFrames(waves.size(), frame_opts); + int32 frame_shift = frame_opts.WindowShift(); + int32 left_samples = waves.size() - frame_shift * num_frames; + remained_wav_.resize(left_samples); + std::memcpy(remained_wav_.data(), + waves.data() + frame_shift * num_frames, + left_samples * sizeof(BaseFloat)); + return true; +} + +// Compute feat +template +bool StreamingFeatureTpl::Compute(const std::vector& waves, + std::vector* feats) { + const knf::FrameExtractionOptions& frame_opts = computer_.GetFrameOptions(); + int32 num_samples = waves.size(); + int32 frame_length = frame_opts.WindowSize(); + int32 sample_rate = frame_opts.samp_freq; + if (num_samples < frame_length) { + return true; + } + + int32 num_frames = knf::NumFrames(num_samples, frame_opts); + feats->resize(num_frames * Dim()); + + std::vector window; + bool need_raw_log_energy = computer_.NeedRawLogEnergy(); + for 
(int32 frame = 0; frame < num_frames; frame++) { + std::fill(window.begin(), window.end(), 0); + kaldi::BaseFloat raw_log_energy = 0.0; + kaldi::BaseFloat vtln_warp = 1.0; + knf::ExtractWindow(0, + waves, + frame, + frame_opts, + window_function_, + &window, + need_raw_log_energy ? &raw_log_energy : NULL); + + std::vector this_feature(computer_.Dim()); + computer_.Compute( + raw_log_energy, vtln_warp, &window, this_feature.data()); + std::memcpy(feats->data() + frame * Dim(), + this_feature.data(), + sizeof(BaseFloat) * Dim()); + } + return true; +} + +} // namespace ppspeech diff --git a/speechx/speechx/frontend/audio/feature_pipeline.cc b/runtime/engine/common/frontend/feature_pipeline.cc similarity index 61% rename from speechx/speechx/frontend/audio/feature_pipeline.cc rename to runtime/engine/common/frontend/feature_pipeline.cc index 2931b96b..7d662bc1 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.cc +++ b/runtime/engine/common/frontend/feature_pipeline.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "frontend/audio/feature_pipeline.h" +#include "frontend/feature_pipeline.h" namespace ppspeech { @@ -21,24 +21,25 @@ using std::unique_ptr; FeaturePipeline::FeaturePipeline(const FeaturePipelineOptions& opts) : opts_(opts) { unique_ptr data_source( - new ppspeech::AudioCache(1000 * kint16max, opts.to_float32)); + new ppspeech::AudioCache(1000 * kint16max, false)); unique_ptr base_feature; - if (opts.use_fbank) { - base_feature.reset( - new ppspeech::Fbank(opts.fbank_opts, std::move(data_source))); - } else { - base_feature.reset(new ppspeech::LinearSpectrogram( - opts.linear_spectrogram_opts, std::move(data_source))); - } + base_feature.reset( + new ppspeech::Fbank(opts.fbank_opts, std::move(data_source))); - CHECK_NE(opts.cmvn_file, ""); - unique_ptr cmvn( - new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); + // CHECK_NE(opts.cmvn_file, ""); + unique_ptr cache; + if (opts.cmvn_file != ""){ + unique_ptr cmvn( + new ppspeech::CMVN(opts.cmvn_file, std::move(base_feature))); - unique_ptr cache( - new ppspeech::FeatureCache(opts.feature_cache_opts, std::move(cmvn))); + cache.reset( + new ppspeech::FeatureCache(kint16max, std::move(cmvn))); + } else { + cache.reset( + new ppspeech::FeatureCache(kint16max, std::move(base_feature))); + } base_extractor_.reset( new ppspeech::Assembler(opts.assembler_opts, std::move(cache))); diff --git a/speechx/speechx/frontend/audio/feature_pipeline.h b/runtime/engine/common/frontend/feature_pipeline.h similarity index 67% rename from speechx/speechx/frontend/audio/feature_pipeline.h rename to runtime/engine/common/frontend/feature_pipeline.h index e83a3f31..7509814f 100644 --- a/speechx/speechx/frontend/audio/feature_pipeline.h +++ b/runtime/engine/common/frontend/feature_pipeline.h @@ -16,17 +16,15 @@ #pragma once -#include "frontend/audio/assembler.h" -#include "frontend/audio/audio_cache.h" -#include "frontend/audio/data_cache.h" -#include "frontend/audio/fbank.h" -#include "frontend/audio/feature_cache.h" -#include "frontend/audio/frontend_itf.h" -#include "frontend/audio/linear_spectrogram.h" -#include "frontend/audio/normalizer.h" +#include "frontend/assembler.h" +#include "frontend/audio_cache.h" +#include "frontend/cmvn.h" +#include "frontend/data_cache.h" +#include "frontend/fbank.h" +#include "frontend/feature_cache.h" +#include "frontend/frontend_itf.h" // feature -DECLARE_bool(use_fbank); DECLARE_bool(fill_zero); 
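With the condition variables removed, FeatureCache is a pure pull model: Read() drains the queue and refills it from the base extractor via Compute() only when the queue is empty. A hedged usage sketch of the resulting pipeline API; DecodeChunk is illustrative, and it assumes FeaturePipeline exposes SetFinished() through FrontendInterface the way FeatureCache does:

    #include <vector>

    #include "frontend/feature_pipeline.h"

    // Hypothetical driver: push one chunk of samples through the refactored
    // pipeline (fbank -> optional CMVN -> cache -> assembler) and drain it.
    void DecodeChunk(ppspeech::FeaturePipeline *pipeline,
                     const std::vector<kaldi::BaseFloat> &chunk,
                     bool is_last_chunk) {
        pipeline->Accept(chunk);
        if (is_last_chunk) {
            pipeline->SetFinished();  // lets the cache flush its tail frames
        }

        std::vector<kaldi::BaseFloat> feats;
        // Read() now returns false when nothing is buffered instead of
        // blocking on a condition variable, so callers simply loop.
        while (pipeline->Read(&feats)) {
            // feats holds num_frames * Dim() floats; hand off to the decoder.
        }
    }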
DECLARE_int32(num_bins); DECLARE_string(cmvn_file); @@ -40,11 +38,7 @@ namespace ppspeech { struct FeaturePipelineOptions { std::string cmvn_file{}; - bool to_float32{false}; // true, only for linear feature - bool use_fbank{true}; - LinearSpectrogramOptions linear_spectrogram_opts{}; - kaldi::FbankOptions fbank_opts{}; - FeatureCacheOptions feature_cache_opts{}; + knf::FbankOptions fbank_opts{}; AssemblerOptions assembler_opts{}; static FeaturePipelineOptions InitFromFlags() { @@ -53,30 +47,17 @@ struct FeaturePipelineOptions { LOG(INFO) << "cmvn file: " << opts.cmvn_file; // frame options - kaldi::FrameExtractionOptions frame_opts; + knf::FrameExtractionOptions frame_opts; frame_opts.dither = 0.0; LOG(INFO) << "dither: " << frame_opts.dither; frame_opts.frame_shift_ms = 10; LOG(INFO) << "frame shift ms: " << frame_opts.frame_shift_ms; - opts.use_fbank = FLAGS_use_fbank; - LOG(INFO) << "feature type: " << (opts.use_fbank ? "fbank" : "linear"); - if (opts.use_fbank) { - opts.to_float32 = false; - frame_opts.window_type = "povey"; - frame_opts.frame_length_ms = 25; - opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; - LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins; - - opts.fbank_opts.frame_opts = frame_opts; - } else { - opts.to_float32 = true; - frame_opts.remove_dc_offset = false; - frame_opts.frame_length_ms = 20; - frame_opts.window_type = "hanning"; - frame_opts.preemph_coeff = 0.0; - - opts.linear_spectrogram_opts.frame_opts = frame_opts; - } + frame_opts.window_type = "povey"; + frame_opts.frame_length_ms = 25; + opts.fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + LOG(INFO) << "num bins: " << opts.fbank_opts.mel_opts.num_bins; + + opts.fbank_opts.frame_opts = frame_opts; LOG(INFO) << "frame length ms: " << frame_opts.frame_length_ms; // assembler opts @@ -100,10 +81,10 @@ struct FeaturePipelineOptions { class FeaturePipeline : public FrontendInterface { public: explicit FeaturePipeline(const FeaturePipelineOptions& opts); - virtual void Accept(const kaldi::VectorBase& waves) { + virtual void Accept(const std::vector& waves) { base_extractor_->Accept(waves); } - virtual bool Read(kaldi::Vector* feats) { + virtual bool Read(std::vector* feats) { return base_extractor_->Read(feats); } virtual size_t Dim() const { return base_extractor_->Dim(); } diff --git a/runtime/engine/common/frontend/fftsg.c b/runtime/engine/common/frontend/fftsg.c new file mode 100644 index 00000000..30b81604 --- /dev/null +++ b/runtime/engine/common/frontend/fftsg.c @@ -0,0 +1,3271 @@ +/* This file is copied from + * https://www.kurims.kyoto-u.ac.jp/~ooura/fft.html + */ +/* +Fast Fourier/Cosine/Sine Transform + dimension :one + data length :power of 2 + decimation :frequency + radix :split-radix + data :inplace + table :use +functions + cdft: Complex Discrete Fourier Transform + rdft: Real Discrete Fourier Transform + ddct: Discrete Cosine Transform + ddst: Discrete Sine Transform + dfct: Cosine Transform of RDFT (Real Symmetric DFT) + dfst: Sine Transform of RDFT (Real Anti-symmetric DFT) +function prototypes + void cdft(int, int, double *, int *, double *); + void rdft(int, int, double *, int *, double *); + void ddct(int, int, double *, int *, double *); + void ddst(int, int, double *, int *, double *); + void dfct(int, double *, double *, int *, double *); + void dfst(int, double *, double *, int *, double *); +macro definitions + USE_CDFT_PTHREADS : default=not defined + CDFT_THREADS_BEGIN_N : must be >= 512, default=8192 + CDFT_4THREADS_BEGIN_N : must be >= 512, default=65536 + 
USE_CDFT_WINTHREADS : default=not defined + CDFT_THREADS_BEGIN_N : must be >= 512, default=32768 + CDFT_4THREADS_BEGIN_N : must be >= 512, default=524288 + + +-------- Complex DFT (Discrete Fourier Transform) -------- + [definition] + + X[k] = sum_j=0^n-1 x[j]*exp(2*pi*i*j*k/n), 0<=k + X[k] = sum_j=0^n-1 x[j]*exp(-2*pi*i*j*k/n), 0<=k + ip[0] = 0; // first time only + cdft(2*n, 1, a, ip, w); + + ip[0] = 0; // first time only + cdft(2*n, -1, a, ip, w); + [parameters] + 2*n :data length (int) + n >= 1, n = power of 2 + a[0...2*n-1] :input/output data (double *) + input data + a[2*j] = Re(x[j]), + a[2*j+1] = Im(x[j]), 0<=j= 2+sqrt(n) + strictly, + length of ip >= + 2+(1<<(int)(log(n+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n/2-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + cdft(2*n, -1, a, ip, w); + is + cdft(2*n, 1, a, ip, w); + for (j = 0; j <= 2 * n - 1; j++) { + a[j] *= 1.0 / n; + } + . + + +-------- Real DFT / Inverse of Real DFT -------- + [definition] + RDFT + R[k] = sum_j=0^n-1 a[j]*cos(2*pi*j*k/n), 0<=k<=n/2 + I[k] = sum_j=0^n-1 a[j]*sin(2*pi*j*k/n), 0 IRDFT (excluding scale) + a[k] = (R[0] + R[n/2]*cos(pi*k))/2 + + sum_j=1^n/2-1 R[j]*cos(2*pi*j*k/n) + + sum_j=1^n/2-1 I[j]*sin(2*pi*j*k/n), 0<=k + ip[0] = 0; // first time only + rdft(n, 1, a, ip, w); + + ip[0] = 0; // first time only + rdft(n, -1, a, ip, w); + [parameters] + n :data length (int) + n >= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + + output data + a[2*k] = R[k], 0<=k + input data + a[2*j] = R[j], 0<=j= 2+sqrt(n/2) + strictly, + length of ip >= + 2+(1<<(int)(log(n/2+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n/2-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + rdft(n, 1, a, ip, w); + is + rdft(n, -1, a, ip, w); + for (j = 0; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . + + +-------- DCT (Discrete Cosine Transform) / Inverse of DCT -------- + [definition] + IDCT (excluding scale) + C[k] = sum_j=0^n-1 a[j]*cos(pi*j*(k+1/2)/n), 0<=k DCT + C[k] = sum_j=0^n-1 a[j]*cos(pi*(j+1/2)*k/n), 0<=k + ip[0] = 0; // first time only + ddct(n, 1, a, ip, w); + + ip[0] = 0; // first time only + ddct(n, -1, a, ip, w); + [parameters] + n :data length (int) + n >= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + output data + a[k] = C[k], 0<=k= 2+sqrt(n/2) + strictly, + length of ip >= + 2+(1<<(int)(log(n/2+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/4-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + ddct(n, -1, a, ip, w); + is + a[0] *= 0.5; + ddct(n, 1, a, ip, w); + for (j = 0; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . + + +-------- DST (Discrete Sine Transform) / Inverse of DST -------- + [definition] + IDST (excluding scale) + S[k] = sum_j=1^n A[j]*sin(pi*j*(k+1/2)/n), 0<=k DST + S[k] = sum_j=0^n-1 a[j]*sin(pi*(j+1/2)*k/n), 0 + ip[0] = 0; // first time only + ddst(n, 1, a, ip, w); + + ip[0] = 0; // first time only + ddst(n, -1, a, ip, w); + [parameters] + n :data length (int) + n >= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + + input data + a[j] = A[j], 0 + output data + a[k] = S[k], 0= 2+sqrt(n/2) + strictly, + length of ip >= + 2+(1<<(int)(log(n/2+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/4-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. 
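Based on the parameter tables and [remark] sections above, a hedged round-trip sketch for rdft(); the work-area sizes follow the table (ip length >= 2+sqrt(n/2), w length n/2) and the 2.0/n rescaling comes from the remark. The caller is C++ here, so the symbol is declared with C linkage, assuming fftsg.c is compiled as C as its extension suggests:

    #include <cmath>
    #include <vector>

    extern "C" void rdft(int n, int isgn, double *a, int *ip, double *w);

    int main() {
        const int n = 8;  // data length, must be a power of 2
        std::vector<double> a = {1, 2, 3, 4, 5, 6, 7, 8};  // made-up samples
        const std::vector<double> saved = a;

        std::vector<int> ip(2 + static_cast<int>(std::sqrt(n / 2.0)) + 1, 0);
        std::vector<double> w(n / 2);
        ip[0] = 0;  // first time only: cos/sin tables are built on demand

        rdft(n, 1, a.data(), ip.data(), w.data());   // forward transform
        rdft(n, -1, a.data(), ip.data(), w.data());  // inverse (unscaled)
        for (int j = 0; j < n; ++j) {
            a[j] *= 2.0 / n;  // rescaling given in the [remark] section
        }
        // a now matches saved up to rounding error.
        return 0;
    }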
+ [remark] + Inverse of + ddst(n, -1, a, ip, w); + is + a[0] *= 0.5; + ddst(n, 1, a, ip, w); + for (j = 0; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . + + +-------- Cosine Transform of RDFT (Real Symmetric DFT) -------- + [definition] + C[k] = sum_j=0^n a[j]*cos(pi*j*k/n), 0<=k<=n + [usage] + ip[0] = 0; // first time only + dfct(n, a, t, ip, w); + [parameters] + n :data length - 1 (int) + n >= 2, n = power of 2 + a[0...n] :input/output data (double *) + output data + a[k] = C[k], 0<=k<=n + t[0...n/2] :work area (double *) + ip[0...*] :work area for bit reversal (int *) + length of ip >= 2+sqrt(n/4) + strictly, + length of ip >= + 2+(1<<(int)(log(n/4+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/8-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + a[0] *= 0.5; + a[n] *= 0.5; + dfct(n, a, t, ip, w); + is + a[0] *= 0.5; + a[n] *= 0.5; + dfct(n, a, t, ip, w); + for (j = 0; j <= n; j++) { + a[j] *= 2.0 / n; + } + . + + +-------- Sine Transform of RDFT (Real Anti-symmetric DFT) -------- + [definition] + S[k] = sum_j=1^n-1 a[j]*sin(pi*j*k/n), 0= 2, n = power of 2 + a[0...n-1] :input/output data (double *) + output data + a[k] = S[k], 0= 2+sqrt(n/4) + strictly, + length of ip >= + 2+(1<<(int)(log(n/4+0.5)/log(2))/2). + ip[0],ip[1] are pointers of the cos/sin table. + w[0...n*5/8-1] :cos/sin table (double *) + w[],ip[] are initialized if ip[0] == 0. + [remark] + Inverse of + dfst(n, a, t, ip, w); + is + dfst(n, a, t, ip, w); + for (j = 1; j <= n - 1; j++) { + a[j] *= 2.0 / n; + } + . + + +Appendix : + The cos/sin table is recalculated when the larger table required. + w[] and ip[] are compatible with all routines. +*/ + + +void cdft(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + int nw; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + if (isgn >= 0) { + cftfsub(n, a, ip, nw, w); + } else { + cftbsub(n, a, ip, nw, w); + } +} + + +void rdft(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void rftbsub(int n, double *a, int nc, double *c); + int nw, nc; + double xi; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > (nc << 2)) { + nc = n >> 2; + makect(nc, ip, w + nw); + } + if (isgn >= 0) { + if (n > 4) { + cftfsub(n, a, ip, nw, w); + rftfsub(n, a, nc, w + nw); + } else if (n == 4) { + cftfsub(n, a, ip, nw, w); + } + xi = a[0] - a[1]; + a[0] += a[1]; + a[1] = xi; + } else { + a[1] = 0.5 * (a[0] - a[1]); + a[0] -= a[1]; + if (n > 4) { + rftbsub(n, a, nc, w + nw); + cftbsub(n, a, ip, nw, w); + } else if (n == 4) { + cftbsub(n, a, ip, nw, w); + } + } +} + + +void ddct(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void rftbsub(int n, double *a, int nc, double *c); + void dctsub(int n, double *a, int nc, double *c); + int j, nw, nc; + double xr; + + 
nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > nc) { + nc = n; + makect(nc, ip, w + nw); + } + if (isgn < 0) { + xr = a[n - 1]; + for (j = n - 2; j >= 2; j -= 2) { + a[j + 1] = a[j] - a[j - 1]; + a[j] += a[j - 1]; + } + a[1] = a[0] - xr; + a[0] += xr; + if (n > 4) { + rftbsub(n, a, nc, w + nw); + cftbsub(n, a, ip, nw, w); + } else if (n == 4) { + cftbsub(n, a, ip, nw, w); + } + } + dctsub(n, a, nc, w + nw); + if (isgn >= 0) { + if (n > 4) { + cftfsub(n, a, ip, nw, w); + rftfsub(n, a, nc, w + nw); + } else if (n == 4) { + cftfsub(n, a, ip, nw, w); + } + xr = a[0] - a[1]; + a[0] += a[1]; + for (j = 2; j < n; j += 2) { + a[j - 1] = a[j] - a[j + 1]; + a[j] += a[j + 1]; + } + a[n - 1] = xr; + } +} + + +void ddst(int n, int isgn, double *a, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void cftbsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void rftbsub(int n, double *a, int nc, double *c); + void dstsub(int n, double *a, int nc, double *c); + int j, nw, nc; + double xr; + + nw = ip[0]; + if (n > (nw << 2)) { + nw = n >> 2; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > nc) { + nc = n; + makect(nc, ip, w + nw); + } + if (isgn < 0) { + xr = a[n - 1]; + for (j = n - 2; j >= 2; j -= 2) { + a[j + 1] = -a[j] - a[j - 1]; + a[j] -= a[j - 1]; + } + a[1] = a[0] + xr; + a[0] -= xr; + if (n > 4) { + rftbsub(n, a, nc, w + nw); + cftbsub(n, a, ip, nw, w); + } else if (n == 4) { + cftbsub(n, a, ip, nw, w); + } + } + dstsub(n, a, nc, w + nw); + if (isgn >= 0) { + if (n > 4) { + cftfsub(n, a, ip, nw, w); + rftfsub(n, a, nc, w + nw); + } else if (n == 4) { + cftfsub(n, a, ip, nw, w); + } + xr = a[0] - a[1]; + a[0] += a[1]; + for (j = 2; j < n; j += 2) { + a[j - 1] = -a[j] - a[j + 1]; + a[j] -= a[j + 1]; + } + a[n - 1] = -xr; + } +} + + +void dfct(int n, double *a, double *t, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void dctsub(int n, double *a, int nc, double *c); + int j, k, l, m, mh, nw, nc; + double xr, xi, yr, yi; + + nw = ip[0]; + if (n > (nw << 3)) { + nw = n >> 3; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > (nc << 1)) { + nc = n >> 1; + makect(nc, ip, w + nw); + } + m = n >> 1; + yi = a[m]; + xi = a[0] + a[n]; + a[0] -= a[n]; + t[0] = xi - yi; + t[m] = xi + yi; + if (n > 2) { + mh = m >> 1; + for (j = 1; j < mh; j++) { + k = m - j; + xr = a[j] - a[n - j]; + xi = a[j] + a[n - j]; + yr = a[k] - a[n - k]; + yi = a[k] + a[n - k]; + a[j] = xr; + a[k] = yr; + t[j] = xi - yi; + t[k] = xi + yi; + } + t[mh] = a[mh] + a[n - mh]; + a[mh] -= a[n - mh]; + dctsub(m, a, nc, w + nw); + if (m > 4) { + cftfsub(m, a, ip, nw, w); + rftfsub(m, a, nc, w + nw); + } else if (m == 4) { + cftfsub(m, a, ip, nw, w); + } + a[n - 1] = a[0] - a[1]; + a[1] = a[0] + a[1]; + for (j = m - 2; j >= 2; j -= 2) { + a[2 * j + 1] = a[j] + a[j + 1]; + a[2 * j - 1] = a[j] - a[j + 1]; + } + l = 2; + m = mh; + while (m >= 2) { + dctsub(m, t, nc, w + nw); + if (m > 4) { + cftfsub(m, t, ip, nw, w); + rftfsub(m, t, nc, w + nw); + } else if (m == 4) { + cftfsub(m, t, ip, nw, w); + } + a[n - l] = t[0] - t[1]; + a[l] = t[0] + t[1]; + k = 0; + for (j = 2; j < m; j += 2) { + k += l << 2; + a[k - l] = t[j] - t[j + 1]; + a[k + l] = t[j] + 
t[j + 1]; + } + l <<= 1; + mh = m >> 1; + for (j = 0; j < mh; j++) { + k = m - j; + t[j] = t[m + k] - t[m + j]; + t[k] = t[m + k] + t[m + j]; + } + t[mh] = t[m + mh]; + m = mh; + } + a[l] = t[0]; + a[n] = t[2] - t[1]; + a[0] = t[2] + t[1]; + } else { + a[1] = a[0]; + a[2] = t[0]; + a[0] = t[1]; + } +} + + +void dfst(int n, double *a, double *t, int *ip, double *w) { + void makewt(int nw, int *ip, double *w); + void makect(int nc, int *ip, double *c); + void cftfsub(int n, double *a, int *ip, int nw, double *w); + void rftfsub(int n, double *a, int nc, double *c); + void dstsub(int n, double *a, int nc, double *c); + int j, k, l, m, mh, nw, nc; + double xr, xi, yr, yi; + + nw = ip[0]; + if (n > (nw << 3)) { + nw = n >> 3; + makewt(nw, ip, w); + } + nc = ip[1]; + if (n > (nc << 1)) { + nc = n >> 1; + makect(nc, ip, w + nw); + } + if (n > 2) { + m = n >> 1; + mh = m >> 1; + for (j = 1; j < mh; j++) { + k = m - j; + xr = a[j] + a[n - j]; + xi = a[j] - a[n - j]; + yr = a[k] + a[n - k]; + yi = a[k] - a[n - k]; + a[j] = xr; + a[k] = yr; + t[j] = xi + yi; + t[k] = xi - yi; + } + t[0] = a[mh] - a[n - mh]; + a[mh] += a[n - mh]; + a[0] = a[m]; + dstsub(m, a, nc, w + nw); + if (m > 4) { + cftfsub(m, a, ip, nw, w); + rftfsub(m, a, nc, w + nw); + } else if (m == 4) { + cftfsub(m, a, ip, nw, w); + } + a[n - 1] = a[1] - a[0]; + a[1] = a[0] + a[1]; + for (j = m - 2; j >= 2; j -= 2) { + a[2 * j + 1] = a[j] - a[j + 1]; + a[2 * j - 1] = -a[j] - a[j + 1]; + } + l = 2; + m = mh; + while (m >= 2) { + dstsub(m, t, nc, w + nw); + if (m > 4) { + cftfsub(m, t, ip, nw, w); + rftfsub(m, t, nc, w + nw); + } else if (m == 4) { + cftfsub(m, t, ip, nw, w); + } + a[n - l] = t[1] - t[0]; + a[l] = t[0] + t[1]; + k = 0; + for (j = 2; j < m; j += 2) { + k += l << 2; + a[k - l] = -t[j] - t[j + 1]; + a[k + l] = t[j] - t[j + 1]; + } + l <<= 1; + mh = m >> 1; + for (j = 1; j < mh; j++) { + k = m - j; + t[j] = t[m + k] + t[m + j]; + t[k] = t[m + k] - t[m + j]; + } + t[0] = t[m + mh]; + m = mh; + } + a[l] = t[0]; + } + a[0] = 0; +} + + +/* -------- initializing routines -------- */ + + +#include + +void makewt(int nw, int *ip, double *w) { + void makeipt(int nw, int *ip); + int j, nwh, nw0, nw1; + double delta, wn4r, wk1r, wk1i, wk3r, wk3i; + + ip[0] = nw; + ip[1] = 1; + if (nw > 2) { + nwh = nw >> 1; + delta = atan(1.0) / nwh; + wn4r = cos(delta * nwh); + w[0] = 1; + w[1] = wn4r; + if (nwh == 4) { + w[2] = cos(delta * 2); + w[3] = sin(delta * 2); + } else if (nwh > 4) { + makeipt(nw, ip); + w[2] = 0.5 / cos(delta * 2); + w[3] = 0.5 / cos(delta * 6); + for (j = 4; j < nwh; j += 4) { + w[j] = cos(delta * j); + w[j + 1] = sin(delta * j); + w[j + 2] = cos(3 * delta * j); + w[j + 3] = -sin(3 * delta * j); + } + } + nw0 = 0; + while (nwh > 2) { + nw1 = nw0 + nwh; + nwh >>= 1; + w[nw1] = 1; + w[nw1 + 1] = wn4r; + if (nwh == 4) { + wk1r = w[nw0 + 4]; + wk1i = w[nw0 + 5]; + w[nw1 + 2] = wk1r; + w[nw1 + 3] = wk1i; + } else if (nwh > 4) { + wk1r = w[nw0 + 4]; + wk3r = w[nw0 + 6]; + w[nw1 + 2] = 0.5 / wk1r; + w[nw1 + 3] = 0.5 / wk3r; + for (j = 4; j < nwh; j += 4) { + wk1r = w[nw0 + 2 * j]; + wk1i = w[nw0 + 2 * j + 1]; + wk3r = w[nw0 + 2 * j + 2]; + wk3i = w[nw0 + 2 * j + 3]; + w[nw1 + j] = wk1r; + w[nw1 + j + 1] = wk1i; + w[nw1 + j + 2] = wk3r; + w[nw1 + j + 3] = wk3i; + } + } + nw0 = nw1; + } + } +} + + +void makeipt(int nw, int *ip) { + int j, l, m, m2, p, q; + + ip[2] = 0; + ip[3] = 16; + m = 2; + for (l = nw; l > 32; l >>= 2) { + m2 = m << 1; + q = m2 << 3; + for (j = m; j < m2; j++) { + p = ip[j] << 2; + ip[m + j] = p; + ip[m2 + j] = p + 
q; + } + m = m2; + } +} + + +void makect(int nc, int *ip, double *c) { + int j, nch; + double delta; + + ip[1] = nc; + if (nc > 1) { + nch = nc >> 1; + delta = atan(1.0) / nch; + c[0] = cos(delta * nch); + c[nch] = 0.5 * c[0]; + for (j = 1; j < nch; j++) { + c[j] = 0.5 * cos(delta * j); + c[nc - j] = 0.5 * sin(delta * j); + } + } +} + + +/* -------- child routines -------- */ + + +#ifdef USE_CDFT_PTHREADS +#define USE_CDFT_THREADS +#ifndef CDFT_THREADS_BEGIN_N +#define CDFT_THREADS_BEGIN_N 8192 +#endif +#ifndef CDFT_4THREADS_BEGIN_N +#define CDFT_4THREADS_BEGIN_N 65536 +#endif +#include +#include +#include +#define cdft_thread_t pthread_t +#define cdft_thread_create(thp, func, argp) \ + { \ + if (pthread_create(thp, NULL, func, (void *)argp) != 0) { \ + fprintf(stderr, "cdft thread error\n"); \ + exit(1); \ + } \ + } +#define cdft_thread_wait(th) \ + { \ + if (pthread_join(th, NULL) != 0) { \ + fprintf(stderr, "cdft thread error\n"); \ + exit(1); \ + } \ + } +#endif /* USE_CDFT_PTHREADS */ + + +#ifdef USE_CDFT_WINTHREADS +#define USE_CDFT_THREADS +#ifndef CDFT_THREADS_BEGIN_N +#define CDFT_THREADS_BEGIN_N 32768 +#endif +#ifndef CDFT_4THREADS_BEGIN_N +#define CDFT_4THREADS_BEGIN_N 524288 +#endif +#include +#include +#include +#define cdft_thread_t HANDLE +#define cdft_thread_create(thp, func, argp) \ + { \ + DWORD thid; \ + *(thp) = CreateThread( \ + NULL, 0, (LPTHREAD_START_ROUTINE)func, (LPVOID)argp, 0, &thid); \ + if (*(thp) == 0) { \ + fprintf(stderr, "cdft thread error\n"); \ + exit(1); \ + } \ + } +#define cdft_thread_wait(th) \ + { \ + WaitForSingleObject(th, INFINITE); \ + CloseHandle(th); \ + } +#endif /* USE_CDFT_WINTHREADS */ + + +void cftfsub(int n, double *a, int *ip, int nw, double *w) { + void bitrv2(int n, int *ip, double *a); + void bitrv216(double *a); + void bitrv208(double *a); + void cftf1st(int n, double *a, double *w); + void cftrec4(int n, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftfx41(int n, double *a, int nw, double *w); + void cftf161(double *a, double *w); + void cftf081(double *a, double *w); + void cftf040(double *a); + void cftx020(double *a); +#ifdef USE_CDFT_THREADS + void cftrec4_th(int n, double *a, int nw, double *w); +#endif /* USE_CDFT_THREADS */ + + if (n > 8) { + if (n > 32) { + cftf1st(n, a, &w[nw - (n >> 2)]); +#ifdef USE_CDFT_THREADS + if (n > CDFT_THREADS_BEGIN_N) { + cftrec4_th(n, a, nw, w); + } else +#endif /* USE_CDFT_THREADS */ + if (n > 512) { + cftrec4(n, a, nw, w); + } else if (n > 128) { + cftleaf(n, 1, a, nw, w); + } else { + cftfx41(n, a, nw, w); + } + bitrv2(n, ip, a); + } else if (n == 32) { + cftf161(a, &w[nw - 8]); + bitrv216(a); + } else { + cftf081(a, w); + bitrv208(a); + } + } else if (n == 8) { + cftf040(a); + } else if (n == 4) { + cftx020(a); + } +} + + +void cftbsub(int n, double *a, int *ip, int nw, double *w) { + void bitrv2conj(int n, int *ip, double *a); + void bitrv216neg(double *a); + void bitrv208neg(double *a); + void cftb1st(int n, double *a, double *w); + void cftrec4(int n, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftfx41(int n, double *a, int nw, double *w); + void cftf161(double *a, double *w); + void cftf081(double *a, double *w); + void cftb040(double *a); + void cftx020(double *a); +#ifdef USE_CDFT_THREADS + void cftrec4_th(int n, double *a, int nw, double *w); +#endif /* USE_CDFT_THREADS */ + + if (n > 8) { + if (n > 32) { + cftb1st(n, a, &w[nw - (n >> 2)]); +#ifdef USE_CDFT_THREADS + if 
(n > CDFT_THREADS_BEGIN_N) { + cftrec4_th(n, a, nw, w); + } else +#endif /* USE_CDFT_THREADS */ + if (n > 512) { + cftrec4(n, a, nw, w); + } else if (n > 128) { + cftleaf(n, 1, a, nw, w); + } else { + cftfx41(n, a, nw, w); + } + bitrv2conj(n, ip, a); + } else if (n == 32) { + cftf161(a, &w[nw - 8]); + bitrv216neg(a); + } else { + cftf081(a, w); + bitrv208neg(a); + } + } else if (n == 8) { + cftb040(a); + } else if (n == 4) { + cftx020(a); + } +} + + +void bitrv2(int n, int *ip, double *a) { + int j, j1, k, k1, l, m, nh, nm; + double xr, xi, yr, yi; + + m = 1; + for (l = n >> 2; l > 8; l >>= 2) { + m <<= 1; + } + nh = n >> 1; + nm = 4 * m; + if (l == 8) { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + 2 * ip[m + k]; + k1 = 4 * k + 2 * ip[m + j]; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; + k1 -= 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + 2 * ip[m + k]; + j1 = k1 + 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] 
= xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= 2; + k1 -= nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh + 2; + k1 += nh + 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh - nm; + k1 += 2 * nm - 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + } else { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + ip[m + k]; + k1 = 4 * k + ip[m + j]; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; + k1 -= 2; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + ip[m + k]; + j1 = k1 + 2; + k1 += nh; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = a[j1 + 1]; + yr = a[k1]; + yi = a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + } +} + + +void bitrv2conj(int n, int *ip, double *a) { + int j, j1, k, k1, l, m, nh, nm; + double xr, xi, yr, yi; + + m = 1; + for (l = n >> 2; l > 8; l >>= 2) { + m <<= 1; + } + nh = n >> 1; + nm = 4 * m; + if (l == 8) { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + 2 * ip[m + k]; + k1 = 4 * k + 2 * ip[m + j]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 
+= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; + k1 -= 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + 2 * ip[m + k]; + j1 = k1 + 2; + k1 += nh; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + j1 += nm; + k1 += 2 * nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= 2; + k1 -= nh; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh + 2; + k1 += nh + 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh - nm; + k1 += 2 * nm - 2; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + } + } else { + for (k = 0; k < m; k++) { + for (j = 0; j < k; j++) { + j1 = 4 * j + ip[m + k]; + k1 = 4 * k + ip[m + j]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nh; + k1 += 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += 2; + k1 += nh; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 += nm; + k1 += nm; + xr = a[j1]; + xi = -a[j1 + 1]; 
+ yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nh; + k1 -= 2; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + j1 -= nm; + k1 -= nm; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + } + k1 = 4 * k + ip[m + k]; + j1 = k1 + 2; + k1 += nh; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + j1 += nm; + k1 += nm; + a[j1 - 1] = -a[j1 - 1]; + xr = a[j1]; + xi = -a[j1 + 1]; + yr = a[k1]; + yi = -a[k1 + 1]; + a[j1] = yr; + a[j1 + 1] = yi; + a[k1] = xr; + a[k1 + 1] = xi; + a[k1 + 3] = -a[k1 + 3]; + } + } +} + + +void bitrv216(double *a) { + double x1r, x1i, x2r, x2i, x3r, x3i, x4r, x4i, x5r, x5i, x7r, x7i, x8r, x8i, + x10r, x10i, x11r, x11i, x12r, x12i, x13r, x13i, x14r, x14i; + + x1r = a[2]; + x1i = a[3]; + x2r = a[4]; + x2i = a[5]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x5r = a[10]; + x5i = a[11]; + x7r = a[14]; + x7i = a[15]; + x8r = a[16]; + x8i = a[17]; + x10r = a[20]; + x10i = a[21]; + x11r = a[22]; + x11i = a[23]; + x12r = a[24]; + x12i = a[25]; + x13r = a[26]; + x13i = a[27]; + x14r = a[28]; + x14i = a[29]; + a[2] = x8r; + a[3] = x8i; + a[4] = x4r; + a[5] = x4i; + a[6] = x12r; + a[7] = x12i; + a[8] = x2r; + a[9] = x2i; + a[10] = x10r; + a[11] = x10i; + a[14] = x14r; + a[15] = x14i; + a[16] = x1r; + a[17] = x1i; + a[20] = x5r; + a[21] = x5i; + a[22] = x13r; + a[23] = x13i; + a[24] = x3r; + a[25] = x3i; + a[26] = x11r; + a[27] = x11i; + a[28] = x7r; + a[29] = x7i; +} + + +void bitrv216neg(double *a) { + double x1r, x1i, x2r, x2i, x3r, x3i, x4r, x4i, x5r, x5i, x6r, x6i, x7r, x7i, + x8r, x8i, x9r, x9i, x10r, x10i, x11r, x11i, x12r, x12i, x13r, x13i, + x14r, x14i, x15r, x15i; + + x1r = a[2]; + x1i = a[3]; + x2r = a[4]; + x2i = a[5]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x5r = a[10]; + x5i = a[11]; + x6r = a[12]; + x6i = a[13]; + x7r = a[14]; + x7i = a[15]; + x8r = a[16]; + x8i = a[17]; + x9r = a[18]; + x9i = a[19]; + x10r = a[20]; + x10i = a[21]; + x11r = a[22]; + x11i = a[23]; + x12r = a[24]; + x12i = a[25]; + x13r = a[26]; + x13i = a[27]; + x14r = a[28]; + x14i = a[29]; + x15r = a[30]; + x15i = a[31]; + a[2] = x15r; + a[3] = x15i; + a[4] = x7r; + a[5] = x7i; + a[6] = x11r; + a[7] = x11i; + a[8] = x3r; + a[9] = x3i; + a[10] = x13r; + a[11] = x13i; + a[12] = x5r; + a[13] = x5i; + a[14] = x9r; + a[15] = x9i; + a[16] = x1r; + a[17] = x1i; + a[18] = x14r; + a[19] = x14i; + a[20] = x6r; + a[21] = x6i; + a[22] = x10r; + a[23] = x10i; + a[24] = x2r; + a[25] = x2i; + a[26] = x12r; + a[27] = x12i; + a[28] = x4r; + a[29] = x4i; + a[30] = x8r; + a[31] = x8i; +} + + +void bitrv208(double *a) { + double x1r, x1i, x3r, x3i, x4r, x4i, x6r, x6i; + + x1r = a[2]; + x1i = a[3]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x6r = a[12]; + x6i = a[13]; + a[2] = x4r; + a[3] = x4i; + a[6] = x6r; + a[7] = x6i; + a[8] = x1r; + a[9] = x1i; + a[12] = x3r; + a[13] = x3i; +} + + +void bitrv208neg(double *a) { + double x1r, x1i, x2r, x2i, x3r, x3i, x4r, x4i, x5r, x5i, x6r, x6i, x7r, x7i; + + x1r = a[2]; + x1i = a[3]; + x2r = a[4]; + x2i = a[5]; + x3r = a[6]; + x3i = a[7]; + x4r = a[8]; + x4i = a[9]; + x5r = a[10]; + x5i = a[11]; + x6r = a[12]; + x6i = a[13]; + x7r = a[14]; + x7i = a[15]; + a[2] = x7r; + a[3] = x7i; + a[4] = 
x3r; + a[5] = x3i; + a[6] = x5r; + a[7] = x5i; + a[8] = x1r; + a[9] = x1i; + a[10] = x6r; + a[11] = x6i; + a[12] = x2r; + a[13] = x2i; + a[14] = x4r; + a[15] = x4i; +} + + +void cftf1st(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, m, mh; + double wn4r, csc1, csc3, wk1r, wk1i, wk3r, wk3i, wd1r, wd1i, wd3r, wd3i; + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, y1r, y1i, y2r, y2i, + y3r, y3i; + + mh = n >> 3; + m = 2 * mh; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] + a[j2]; + x0i = a[1] + a[j2 + 1]; + x1r = a[0] - a[j2]; + x1i = a[1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j2] = x1r - x3i; + a[j2 + 1] = x1i + x3r; + a[j3] = x1r + x3i; + a[j3 + 1] = x1i - x3r; + wn4r = w[1]; + csc1 = w[2]; + csc3 = w[3]; + wd1r = 1; + wd1i = 0; + wd3r = 1; + wd3i = 0; + k = 0; + for (j = 2; j < mh - 2; j += 4) { + k += 4; + wk1r = csc1 * (wd1r + w[k]); + wk1i = csc1 * (wd1i + w[k + 1]); + wk3r = csc3 * (wd3r + w[k + 2]); + wk3i = csc3 * (wd3i + w[k + 3]); + wd1r = w[k]; + wd1i = w[k + 1]; + wd3r = w[k + 2]; + wd3i = w[k + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] + a[j2]; + x0i = a[j + 1] + a[j2 + 1]; + x1r = a[j] - a[j2]; + x1i = a[j + 1] - a[j2 + 1]; + y0r = a[j + 2] + a[j2 + 2]; + y0i = a[j + 3] + a[j2 + 3]; + y1r = a[j + 2] - a[j2 + 2]; + y1i = a[j + 3] - a[j2 + 3]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 + 2] + a[j3 + 2]; + y2i = a[j1 + 3] + a[j3 + 3]; + y3r = a[j1 + 2] - a[j3 + 2]; + y3i = a[j1 + 3] - a[j3 + 3]; + a[j] = x0r + x2r; + a[j + 1] = x0i + x2i; + a[j + 2] = y0r + y2r; + a[j + 3] = y0i + y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j1 + 2] = y0r - y2r; + a[j1 + 3] = y0i - y2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1r * x0r - wk1i * x0i; + a[j2 + 1] = wk1r * x0i + wk1i * x0r; + x0r = y1r - y3i; + x0i = y1i + y3r; + a[j2 + 2] = wd1r * x0r - wd1i * x0i; + a[j2 + 3] = wd1r * x0i + wd1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3r * x0r + wk3i * x0i; + a[j3 + 1] = wk3r * x0i - wk3i * x0r; + x0r = y1r + y3i; + x0i = y1i - y3r; + a[j3 + 2] = wd3r * x0r + wd3i * x0i; + a[j3 + 3] = wd3r * x0i - wd3i * x0r; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + y0r = a[j0 - 2] + a[j2 - 2]; + y0i = a[j0 - 1] + a[j2 - 1]; + y1r = a[j0 - 2] - a[j2 - 2]; + y1i = a[j0 - 1] - a[j2 - 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 - 2] + a[j3 - 2]; + y2i = a[j1 - 1] + a[j3 - 1]; + y3r = a[j1 - 2] - a[j3 - 2]; + y3i = a[j1 - 1] - a[j3 - 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i + x2i; + a[j0 - 2] = y0r + y2r; + a[j0 - 1] = y0i + y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j1 - 2] = y0r - y2r; + a[j1 - 1] = y0i - y2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1i * x0r - wk1r * x0i; + a[j2 + 1] = wk1i * x0i + wk1r * x0r; + x0r = y1r - y3i; + x0i = y1i + y3r; + a[j2 - 2] = wd1i * x0r - wd1r * x0i; + a[j2 - 1] = wd1i * x0i + wd1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3i * x0r + wk3r * x0i; + a[j3 + 1] = wk3i * x0i - wk3r * x0r; + x0r = y1r + y3i; + x0i = y1i - y3r; + a[j3 - 2] = wd3i * x0r + wd3r * x0i; + a[j3 - 1] = wd3i * x0i - wd3r * x0r; + } + wk1r = 
csc1 * (wd1r + wn4r); + wk1i = csc1 * (wd1i + wn4r); + wk3r = csc3 * (wd3r - wn4r); + wk3i = csc3 * (wd3i - wn4r); + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0 - 2] + a[j2 - 2]; + x0i = a[j0 - 1] + a[j2 - 1]; + x1r = a[j0 - 2] - a[j2 - 2]; + x1i = a[j0 - 1] - a[j2 - 1]; + x2r = a[j1 - 2] + a[j3 - 2]; + x2i = a[j1 - 1] + a[j3 - 1]; + x3r = a[j1 - 2] - a[j3 - 2]; + x3i = a[j1 - 1] - a[j3 - 1]; + a[j0 - 2] = x0r + x2r; + a[j0 - 1] = x0i + x2i; + a[j1 - 2] = x0r - x2r; + a[j1 - 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2 - 2] = wk1r * x0r - wk1i * x0i; + a[j2 - 1] = wk1r * x0i + wk1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3 - 2] = wk3r * x0r + wk3i * x0i; + a[j3 - 1] = wk3r * x0i - wk3i * x0r; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wn4r * (x0r - x0i); + a[j2 + 1] = wn4r * (x0i + x0r); + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = -wn4r * (x0r + x0i); + a[j3 + 1] = -wn4r * (x0i - x0r); + x0r = a[j0 + 2] + a[j2 + 2]; + x0i = a[j0 + 3] + a[j2 + 3]; + x1r = a[j0 + 2] - a[j2 + 2]; + x1i = a[j0 + 3] - a[j2 + 3]; + x2r = a[j1 + 2] + a[j3 + 2]; + x2i = a[j1 + 3] + a[j3 + 3]; + x3r = a[j1 + 2] - a[j3 + 2]; + x3i = a[j1 + 3] - a[j3 + 3]; + a[j0 + 2] = x0r + x2r; + a[j0 + 3] = x0i + x2i; + a[j1 + 2] = x0r - x2r; + a[j1 + 3] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2 + 2] = wk1i * x0r - wk1r * x0i; + a[j2 + 3] = wk1i * x0i + wk1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3 + 2] = wk3i * x0r + wk3r * x0i; + a[j3 + 3] = wk3i * x0i - wk3r * x0r; +} + + +void cftb1st(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, m, mh; + double wn4r, csc1, csc3, wk1r, wk1i, wk3r, wk3i, wd1r, wd1i, wd3r, wd3i; + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, y1r, y1i, y2r, y2i, + y3r, y3i; + + mh = n >> 3; + m = 2 * mh; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] + a[j2]; + x0i = -a[1] - a[j2 + 1]; + x1r = a[0] - a[j2]; + x1i = -a[1] + a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[0] = x0r + x2r; + a[1] = x0i - x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + a[j2] = x1r + x3i; + a[j2 + 1] = x1i + x3r; + a[j3] = x1r - x3i; + a[j3 + 1] = x1i - x3r; + wn4r = w[1]; + csc1 = w[2]; + csc3 = w[3]; + wd1r = 1; + wd1i = 0; + wd3r = 1; + wd3i = 0; + k = 0; + for (j = 2; j < mh - 2; j += 4) { + k += 4; + wk1r = csc1 * (wd1r + w[k]); + wk1i = csc1 * (wd1i + w[k + 1]); + wk3r = csc3 * (wd3r + w[k + 2]); + wk3i = csc3 * (wd3i + w[k + 3]); + wd1r = w[k]; + wd1i = w[k + 1]; + wd3r = w[k + 2]; + wd3i = w[k + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] + a[j2]; + x0i = -a[j + 1] - a[j2 + 1]; + x1r = a[j] - a[j2]; + x1i = -a[j + 1] + a[j2 + 1]; + y0r = a[j + 2] + a[j2 + 2]; + y0i = -a[j + 3] - a[j2 + 3]; + y1r = a[j + 2] - a[j2 + 2]; + y1i = -a[j + 3] + a[j2 + 3]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 + 2] + a[j3 + 2]; + y2i = a[j1 + 3] + a[j3 + 3]; + y3r = a[j1 + 2] - a[j3 + 2]; + y3i = a[j1 + 3] - a[j3 + 3]; + a[j] = x0r + x2r; + a[j + 1] = x0i - x2i; + a[j + 2] = y0r + y2r; + a[j + 3] = y0i - y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + 
a[j1 + 2] = y0r - y2r; + a[j1 + 3] = y0i + y2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2] = wk1r * x0r - wk1i * x0i; + a[j2 + 1] = wk1r * x0i + wk1i * x0r; + x0r = y1r + y3i; + x0i = y1i + y3r; + a[j2 + 2] = wd1r * x0r - wd1i * x0i; + a[j2 + 3] = wd1r * x0i + wd1i * x0r; + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3] = wk3r * x0r + wk3i * x0i; + a[j3 + 1] = wk3r * x0i - wk3i * x0r; + x0r = y1r - y3i; + x0i = y1i - y3r; + a[j3 + 2] = wd3r * x0r + wd3i * x0i; + a[j3 + 3] = wd3r * x0i - wd3i * x0r; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = -a[j0 + 1] - a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = -a[j0 + 1] + a[j2 + 1]; + y0r = a[j0 - 2] + a[j2 - 2]; + y0i = -a[j0 - 1] - a[j2 - 1]; + y1r = a[j0 - 2] - a[j2 - 2]; + y1i = -a[j0 - 1] + a[j2 - 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + y2r = a[j1 - 2] + a[j3 - 2]; + y2i = a[j1 - 1] + a[j3 - 1]; + y3r = a[j1 - 2] - a[j3 - 2]; + y3i = a[j1 - 1] - a[j3 - 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i - x2i; + a[j0 - 2] = y0r + y2r; + a[j0 - 1] = y0i - y2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + a[j1 - 2] = y0r - y2r; + a[j1 - 1] = y0i + y2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2] = wk1i * x0r - wk1r * x0i; + a[j2 + 1] = wk1i * x0i + wk1r * x0r; + x0r = y1r + y3i; + x0i = y1i + y3r; + a[j2 - 2] = wd1i * x0r - wd1r * x0i; + a[j2 - 1] = wd1i * x0i + wd1r * x0r; + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3] = wk3i * x0r + wk3r * x0i; + a[j3 + 1] = wk3i * x0i - wk3r * x0r; + x0r = y1r - y3i; + x0i = y1i - y3r; + a[j3 - 2] = wd3i * x0r + wd3r * x0i; + a[j3 - 1] = wd3i * x0i - wd3r * x0r; + } + wk1r = csc1 * (wd1r + wn4r); + wk1i = csc1 * (wd1i + wn4r); + wk3r = csc3 * (wd3r - wn4r); + wk3i = csc3 * (wd3i - wn4r); + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0 - 2] + a[j2 - 2]; + x0i = -a[j0 - 1] - a[j2 - 1]; + x1r = a[j0 - 2] - a[j2 - 2]; + x1i = -a[j0 - 1] + a[j2 - 1]; + x2r = a[j1 - 2] + a[j3 - 2]; + x2i = a[j1 - 1] + a[j3 - 1]; + x3r = a[j1 - 2] - a[j3 - 2]; + x3i = a[j1 - 1] - a[j3 - 1]; + a[j0 - 2] = x0r + x2r; + a[j0 - 1] = x0i - x2i; + a[j1 - 2] = x0r - x2r; + a[j1 - 1] = x0i + x2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2 - 2] = wk1r * x0r - wk1i * x0i; + a[j2 - 1] = wk1r * x0i + wk1i * x0r; + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3 - 2] = wk3r * x0r + wk3i * x0i; + a[j3 - 1] = wk3r * x0i - wk3i * x0r; + x0r = a[j0] + a[j2]; + x0i = -a[j0 + 1] - a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = -a[j0 + 1] + a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i - x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i + x2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2] = wn4r * (x0r - x0i); + a[j2 + 1] = wn4r * (x0i + x0r); + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3] = -wn4r * (x0r + x0i); + a[j3 + 1] = -wn4r * (x0i - x0r); + x0r = a[j0 + 2] + a[j2 + 2]; + x0i = -a[j0 + 3] - a[j2 + 3]; + x1r = a[j0 + 2] - a[j2 + 2]; + x1i = -a[j0 + 3] + a[j2 + 3]; + x2r = a[j1 + 2] + a[j3 + 2]; + x2i = a[j1 + 3] + a[j3 + 3]; + x3r = a[j1 + 2] - a[j3 + 2]; + x3i = a[j1 + 3] - a[j3 + 3]; + a[j0 + 2] = x0r + x2r; + a[j0 + 3] = x0i - x2i; + a[j1 + 2] = x0r - x2r; + a[j1 + 3] = x0i + x2i; + x0r = x1r + x3i; + x0i = x1i + x3r; + a[j2 + 2] = wk1i * x0r - wk1r * x0i; + a[j2 + 3] = wk1i * x0i + wk1r * x0r; + x0r = x1r - x3i; + x0i = x1i - x3r; + a[j3 + 2] = wk3i * x0r + wk3r * x0i; + a[j3 + 3] = wk3i * x0i - wk3r * x0r; +} + + 
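Everything from `bitrv2` through `cftb1st` above is Ooura's heavily unrolled fftsg.c; the practical way to review such kernels is to diff their output against a naive O(n²) DFT. Below is a minimal reference of that kind (my sketch, not part of the vendored file), using the same R[k]/I[k] convention that rfft.h documents further down:

```cpp
#include <cmath>
#include <vector>

// Naive O(n^2) real DFT, useful only as a cross-check for optimized kernels.
// R[k] = sum_j in[j]*cos(2*pi*j*k/n), I[k] = sum_j in[j]*sin(2*pi*j*k/n),
// for k = 0 .. n/2.
void NaiveRdft(const std::vector<double> &in,
               std::vector<double> *re,
               std::vector<double> *im) {
    const double kPi = 3.14159265358979323846;
    const int n = static_cast<int>(in.size());
    re->assign(n / 2 + 1, 0.0);
    im->assign(n / 2 + 1, 0.0);
    for (int k = 0; k <= n / 2; ++k) {
        for (int j = 0; j < n; ++j) {
            const double ang = 2.0 * kPi * j * k / n;
            (*re)[k] += in[j] * std::cos(ang);
            (*im)[k] += in[j] * std::sin(ang);
        }
    }
}
```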
+#ifdef USE_CDFT_THREADS +struct cdft_arg_st { + int n0; + int n; + double *a; + int nw; + double *w; +}; +typedef struct cdft_arg_st cdft_arg_t; + + +void cftrec4_th(int n, double *a, int nw, double *w) { + void *cftrec1_th(void *p); + void *cftrec2_th(void *p); + int i, idiv4, m, nthread; + cdft_thread_t th[4]; + cdft_arg_t ag[4]; + + nthread = 2; + idiv4 = 0; + m = n >> 1; + if (n > CDFT_4THREADS_BEGIN_N) { + nthread = 4; + idiv4 = 1; + m >>= 1; + } + for (i = 0; i < nthread; i++) { + ag[i].n0 = n; + ag[i].n = m; + ag[i].a = &a[i * m]; + ag[i].nw = nw; + ag[i].w = w; + if (i != idiv4) { + cdft_thread_create(&th[i], cftrec1_th, &ag[i]); + } else { + cdft_thread_create(&th[i], cftrec2_th, &ag[i]); + } + } + for (i = 0; i < nthread; i++) { + cdft_thread_wait(th[i]); + } +} + + +void *cftrec1_th(void *p) { + int cfttree(int n, int j, int k, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftmdl1(int n, double *a, double *w); + int isplt, j, k, m, n, n0, nw; + double *a, *w; + + n0 = ((cdft_arg_t *)p)->n0; + n = ((cdft_arg_t *)p)->n; + a = ((cdft_arg_t *)p)->a; + nw = ((cdft_arg_t *)p)->nw; + w = ((cdft_arg_t *)p)->w; + m = n0; + while (m > 512) { + m >>= 2; + cftmdl1(m, &a[n - m], &w[nw - (m >> 1)]); + } + cftleaf(m, 1, &a[n - m], nw, w); + k = 0; + for (j = n - m; j > 0; j -= m) { + k++; + isplt = cfttree(m, j, k, a, nw, w); + cftleaf(m, isplt, &a[j - m], nw, w); + } + return (void *)0; +} + + +void *cftrec2_th(void *p) { + int cfttree(int n, int j, int k, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftmdl2(int n, double *a, double *w); + int isplt, j, k, m, n, n0, nw; + double *a, *w; + + n0 = ((cdft_arg_t *)p)->n0; + n = ((cdft_arg_t *)p)->n; + a = ((cdft_arg_t *)p)->a; + nw = ((cdft_arg_t *)p)->nw; + w = ((cdft_arg_t *)p)->w; + k = 1; + m = n0; + while (m > 512) { + m >>= 2; + k <<= 2; + cftmdl2(m, &a[n - m], &w[nw - m]); + } + cftleaf(m, 0, &a[n - m], nw, w); + k >>= 1; + for (j = n - m; j > 0; j -= m) { + k++; + isplt = cfttree(m, j, k, a, nw, w); + cftleaf(m, isplt, &a[j - m], nw, w); + } + return (void *)0; +} +#endif /* USE_CDFT_THREADS */ + + +void cftrec4(int n, double *a, int nw, double *w) { + int cfttree(int n, int j, int k, double *a, int nw, double *w); + void cftleaf(int n, int isplt, double *a, int nw, double *w); + void cftmdl1(int n, double *a, double *w); + int isplt, j, k, m; + + m = n; + while (m > 512) { + m >>= 2; + cftmdl1(m, &a[n - m], &w[nw - (m >> 1)]); + } + cftleaf(m, 1, &a[n - m], nw, w); + k = 0; + for (j = n - m; j > 0; j -= m) { + k++; + isplt = cfttree(m, j, k, a, nw, w); + cftleaf(m, isplt, &a[j - m], nw, w); + } +} + + +int cfttree(int n, int j, int k, double *a, int nw, double *w) { + void cftmdl1(int n, double *a, double *w); + void cftmdl2(int n, double *a, double *w); + int i, isplt, m; + + if ((k & 3) != 0) { + isplt = k & 1; + if (isplt != 0) { + cftmdl1(n, &a[j - n], &w[nw - (n >> 1)]); + } else { + cftmdl2(n, &a[j - n], &w[nw - n]); + } + } else { + m = n; + for (i = k; (i & 3) == 0; i >>= 2) { + m <<= 2; + } + isplt = i & 1; + if (isplt != 0) { + while (m > 128) { + cftmdl1(m, &a[j - m], &w[nw - (m >> 1)]); + m >>= 2; + } + } else { + while (m > 128) { + cftmdl2(m, &a[j - m], &w[nw - m]); + m >>= 2; + } + } + } + return isplt; +} + + +void cftleaf(int n, int isplt, double *a, int nw, double *w) { + void cftmdl1(int n, double *a, double *w); + void cftmdl2(int n, double *a, double *w); + void cftf161(double *a, double *w); + void 
cftf162(double *a, double *w); + void cftf081(double *a, double *w); + void cftf082(double *a, double *w); + + if (n == 512) { + cftmdl1(128, a, &w[nw - 64]); + cftf161(a, &w[nw - 8]); + cftf162(&a[32], &w[nw - 32]); + cftf161(&a[64], &w[nw - 8]); + cftf161(&a[96], &w[nw - 8]); + cftmdl2(128, &a[128], &w[nw - 128]); + cftf161(&a[128], &w[nw - 8]); + cftf162(&a[160], &w[nw - 32]); + cftf161(&a[192], &w[nw - 8]); + cftf162(&a[224], &w[nw - 32]); + cftmdl1(128, &a[256], &w[nw - 64]); + cftf161(&a[256], &w[nw - 8]); + cftf162(&a[288], &w[nw - 32]); + cftf161(&a[320], &w[nw - 8]); + cftf161(&a[352], &w[nw - 8]); + if (isplt != 0) { + cftmdl1(128, &a[384], &w[nw - 64]); + cftf161(&a[480], &w[nw - 8]); + } else { + cftmdl2(128, &a[384], &w[nw - 128]); + cftf162(&a[480], &w[nw - 32]); + } + cftf161(&a[384], &w[nw - 8]); + cftf162(&a[416], &w[nw - 32]); + cftf161(&a[448], &w[nw - 8]); + } else { + cftmdl1(64, a, &w[nw - 32]); + cftf081(a, &w[nw - 8]); + cftf082(&a[16], &w[nw - 8]); + cftf081(&a[32], &w[nw - 8]); + cftf081(&a[48], &w[nw - 8]); + cftmdl2(64, &a[64], &w[nw - 64]); + cftf081(&a[64], &w[nw - 8]); + cftf082(&a[80], &w[nw - 8]); + cftf081(&a[96], &w[nw - 8]); + cftf082(&a[112], &w[nw - 8]); + cftmdl1(64, &a[128], &w[nw - 32]); + cftf081(&a[128], &w[nw - 8]); + cftf082(&a[144], &w[nw - 8]); + cftf081(&a[160], &w[nw - 8]); + cftf081(&a[176], &w[nw - 8]); + if (isplt != 0) { + cftmdl1(64, &a[192], &w[nw - 32]); + cftf081(&a[240], &w[nw - 8]); + } else { + cftmdl2(64, &a[192], &w[nw - 64]); + cftf082(&a[240], &w[nw - 8]); + } + cftf081(&a[192], &w[nw - 8]); + cftf082(&a[208], &w[nw - 8]); + cftf081(&a[224], &w[nw - 8]); + } +} + + +void cftmdl1(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, m, mh; + double wn4r, wk1r, wk1i, wk3r, wk3i; + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i; + + mh = n >> 3; + m = 2 * mh; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] + a[j2]; + x0i = a[1] + a[j2 + 1]; + x1r = a[0] - a[j2]; + x1i = a[1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + a[j2] = x1r - x3i; + a[j2 + 1] = x1i + x3r; + a[j3] = x1r + x3i; + a[j3 + 1] = x1i - x3r; + wn4r = w[1]; + k = 0; + for (j = 2; j < mh; j += 2) { + k += 4; + wk1r = w[k]; + wk1i = w[k + 1]; + wk3r = w[k + 2]; + wk3i = w[k + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] + a[j2]; + x0i = a[j + 1] + a[j2 + 1]; + x1r = a[j] - a[j2]; + x1i = a[j + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j] = x0r + x2r; + a[j + 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1r * x0r - wk1i * x0i; + a[j2 + 1] = wk1r * x0i + wk1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3r * x0r + wk3i * x0i; + a[j3 + 1] = wk3r * x0i - wk3i * x0r; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wk1i * x0r - wk1r * x0i; + a[j2 + 1] = wk1i * x0i + wk1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = wk3i * x0r + wk3r * x0i; + a[j3 + 1] = 
wk3i * x0i - wk3r * x0r; + } + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] + a[j2]; + x0i = a[j0 + 1] + a[j2 + 1]; + x1r = a[j0] - a[j2]; + x1i = a[j0 + 1] - a[j2 + 1]; + x2r = a[j1] + a[j3]; + x2i = a[j1 + 1] + a[j3 + 1]; + x3r = a[j1] - a[j3]; + x3i = a[j1 + 1] - a[j3 + 1]; + a[j0] = x0r + x2r; + a[j0 + 1] = x0i + x2i; + a[j1] = x0r - x2r; + a[j1 + 1] = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + a[j2] = wn4r * (x0r - x0i); + a[j2 + 1] = wn4r * (x0i + x0r); + x0r = x1r + x3i; + x0i = x1i - x3r; + a[j3] = -wn4r * (x0r + x0i); + a[j3 + 1] = -wn4r * (x0i - x0r); +} + + +void cftmdl2(int n, double *a, double *w) { + int j, j0, j1, j2, j3, k, kr, m, mh; + double wn4r, wk1r, wk1i, wk3r, wk3i, wd1r, wd1i, wd3r, wd3i; + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, y2r, y2i; + + mh = n >> 3; + m = 2 * mh; + wn4r = w[1]; + j1 = m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[0] - a[j2 + 1]; + x0i = a[1] + a[j2]; + x1r = a[0] + a[j2 + 1]; + x1i = a[1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wn4r * (x2r - x2i); + y0i = wn4r * (x2i + x2r); + a[0] = x0r + y0r; + a[1] = x0i + y0i; + a[j1] = x0r - y0r; + a[j1 + 1] = x0i - y0i; + y0r = wn4r * (x3r - x3i); + y0i = wn4r * (x3i + x3r); + a[j2] = x1r - y0i; + a[j2 + 1] = x1i + y0r; + a[j3] = x1r + y0i; + a[j3 + 1] = x1i - y0r; + k = 0; + kr = 2 * m; + for (j = 2; j < mh; j += 2) { + k += 4; + wk1r = w[k]; + wk1i = w[k + 1]; + wk3r = w[k + 2]; + wk3i = w[k + 3]; + kr -= 4; + wd1i = w[kr]; + wd1r = w[kr + 1]; + wd3i = w[kr + 2]; + wd3r = w[kr + 3]; + j1 = j + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j] - a[j2 + 1]; + x0i = a[j + 1] + a[j2]; + x1r = a[j] + a[j2 + 1]; + x1i = a[j + 1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wk1r * x0r - wk1i * x0i; + y0i = wk1r * x0i + wk1i * x0r; + y2r = wd1r * x2r - wd1i * x2i; + y2i = wd1r * x2i + wd1i * x2r; + a[j] = y0r + y2r; + a[j + 1] = y0i + y2i; + a[j1] = y0r - y2r; + a[j1 + 1] = y0i - y2i; + y0r = wk3r * x1r + wk3i * x1i; + y0i = wk3r * x1i - wk3i * x1r; + y2r = wd3r * x3r + wd3i * x3i; + y2i = wd3r * x3i - wd3i * x3r; + a[j2] = y0r + y2r; + a[j2 + 1] = y0i + y2i; + a[j3] = y0r - y2r; + a[j3 + 1] = y0i - y2i; + j0 = m - j; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] - a[j2 + 1]; + x0i = a[j0 + 1] + a[j2]; + x1r = a[j0] + a[j2 + 1]; + x1i = a[j0 + 1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wd1i * x0r - wd1r * x0i; + y0i = wd1i * x0i + wd1r * x0r; + y2r = wk1i * x2r - wk1r * x2i; + y2i = wk1i * x2i + wk1r * x2r; + a[j0] = y0r + y2r; + a[j0 + 1] = y0i + y2i; + a[j1] = y0r - y2r; + a[j1 + 1] = y0i - y2i; + y0r = wd3i * x1r + wd3r * x1i; + y0i = wd3i * x1i - wd3r * x1r; + y2r = wk3i * x3r + wk3r * x3i; + y2i = wk3i * x3i - wk3r * x3r; + a[j2] = y0r + y2r; + a[j2 + 1] = y0i + y2i; + a[j3] = y0r - y2r; + a[j3 + 1] = y0i - y2i; + } + wk1r = w[m]; + wk1i = w[m + 1]; + j0 = mh; + j1 = j0 + m; + j2 = j1 + m; + j3 = j2 + m; + x0r = a[j0] - a[j2 + 1]; + x0i = a[j0 + 1] + a[j2]; + x1r = a[j0] + a[j2 + 1]; + x1i = a[j0 + 1] - a[j2]; + x2r = a[j1] - a[j3 + 1]; + x2i = a[j1 + 1] + a[j3]; + x3r = a[j1] + a[j3 + 1]; + x3i = a[j1 + 1] - a[j3]; + y0r = wk1r * x0r - wk1i * x0i; + y0i = wk1r * x0i + wk1i * x0r; + y2r = wk1i * x2r - wk1r * x2i; + y2i = wk1i * x2i + wk1r * x2r; + a[j0] = y0r + y2r; + a[j0 + 1] = y0i + y2i; + a[j1] = y0r - 
y2r; + a[j1 + 1] = y0i - y2i; + y0r = wk1i * x1r - wk1r * x1i; + y0i = wk1i * x1i + wk1r * x1r; + y2r = wk1r * x3r - wk1i * x3i; + y2i = wk1r * x3i + wk1i * x3r; + a[j2] = y0r - y2r; + a[j2 + 1] = y0i - y2i; + a[j3] = y0r + y2r; + a[j3 + 1] = y0i + y2i; +} + + +void cftfx41(int n, double *a, int nw, double *w) { + void cftf161(double *a, double *w); + void cftf162(double *a, double *w); + void cftf081(double *a, double *w); + void cftf082(double *a, double *w); + + if (n == 128) { + cftf161(a, &w[nw - 8]); + cftf162(&a[32], &w[nw - 32]); + cftf161(&a[64], &w[nw - 8]); + cftf161(&a[96], &w[nw - 8]); + } else { + cftf081(a, &w[nw - 8]); + cftf082(&a[16], &w[nw - 8]); + cftf081(&a[32], &w[nw - 8]); + cftf081(&a[48], &w[nw - 8]); + } +} + + +void cftf161(double *a, double *w) { + double wn4r, wk1r, wk1i, x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, + y1r, y1i, y2r, y2i, y3r, y3i, y4r, y4i, y5r, y5i, y6r, y6i, y7r, y7i, + y8r, y8i, y9r, y9i, y10r, y10i, y11r, y11i, y12r, y12i, y13r, y13i, + y14r, y14i, y15r, y15i; + + wn4r = w[1]; + wk1r = w[2]; + wk1i = w[3]; + x0r = a[0] + a[16]; + x0i = a[1] + a[17]; + x1r = a[0] - a[16]; + x1i = a[1] - a[17]; + x2r = a[8] + a[24]; + x2i = a[9] + a[25]; + x3r = a[8] - a[24]; + x3i = a[9] - a[25]; + y0r = x0r + x2r; + y0i = x0i + x2i; + y4r = x0r - x2r; + y4i = x0i - x2i; + y8r = x1r - x3i; + y8i = x1i + x3r; + y12r = x1r + x3i; + y12i = x1i - x3r; + x0r = a[2] + a[18]; + x0i = a[3] + a[19]; + x1r = a[2] - a[18]; + x1i = a[3] - a[19]; + x2r = a[10] + a[26]; + x2i = a[11] + a[27]; + x3r = a[10] - a[26]; + x3i = a[11] - a[27]; + y1r = x0r + x2r; + y1i = x0i + x2i; + y5r = x0r - x2r; + y5i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + y9r = wk1r * x0r - wk1i * x0i; + y9i = wk1r * x0i + wk1i * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + y13r = wk1i * x0r - wk1r * x0i; + y13i = wk1i * x0i + wk1r * x0r; + x0r = a[4] + a[20]; + x0i = a[5] + a[21]; + x1r = a[4] - a[20]; + x1i = a[5] - a[21]; + x2r = a[12] + a[28]; + x2i = a[13] + a[29]; + x3r = a[12] - a[28]; + x3i = a[13] - a[29]; + y2r = x0r + x2r; + y2i = x0i + x2i; + y6r = x0r - x2r; + y6i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + y10r = wn4r * (x0r - x0i); + y10i = wn4r * (x0i + x0r); + x0r = x1r + x3i; + x0i = x1i - x3r; + y14r = wn4r * (x0r + x0i); + y14i = wn4r * (x0i - x0r); + x0r = a[6] + a[22]; + x0i = a[7] + a[23]; + x1r = a[6] - a[22]; + x1i = a[7] - a[23]; + x2r = a[14] + a[30]; + x2i = a[15] + a[31]; + x3r = a[14] - a[30]; + x3i = a[15] - a[31]; + y3r = x0r + x2r; + y3i = x0i + x2i; + y7r = x0r - x2r; + y7i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + y11r = wk1i * x0r - wk1r * x0i; + y11i = wk1i * x0i + wk1r * x0r; + x0r = x1r + x3i; + x0i = x1i - x3r; + y15r = wk1r * x0r - wk1i * x0i; + y15i = wk1r * x0i + wk1i * x0r; + x0r = y12r - y14r; + x0i = y12i - y14i; + x1r = y12r + y14r; + x1i = y12i + y14i; + x2r = y13r - y15r; + x2i = y13i - y15i; + x3r = y13r + y15r; + x3i = y13i + y15i; + a[24] = x0r + x2r; + a[25] = x0i + x2i; + a[26] = x0r - x2r; + a[27] = x0i - x2i; + a[28] = x1r - x3i; + a[29] = x1i + x3r; + a[30] = x1r + x3i; + a[31] = x1i - x3r; + x0r = y8r + y10r; + x0i = y8i + y10i; + x1r = y8r - y10r; + x1i = y8i - y10i; + x2r = y9r + y11r; + x2i = y9i + y11i; + x3r = y9r - y11r; + x3i = y9i - y11i; + a[16] = x0r + x2r; + a[17] = x0i + x2i; + a[18] = x0r - x2r; + a[19] = x0i - x2i; + a[20] = x1r - x3i; + a[21] = x1i + x3r; + a[22] = x1r + x3i; + a[23] = x1i - x3r; + x0r = y5r - y7i; + x0i = y5i + y7r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + x0r = 
y5r + y7i; + x0i = y5i - y7r; + x3r = wn4r * (x0r - x0i); + x3i = wn4r * (x0i + x0r); + x0r = y4r - y6i; + x0i = y4i + y6r; + x1r = y4r + y6i; + x1i = y4i - y6r; + a[8] = x0r + x2r; + a[9] = x0i + x2i; + a[10] = x0r - x2r; + a[11] = x0i - x2i; + a[12] = x1r - x3i; + a[13] = x1i + x3r; + a[14] = x1r + x3i; + a[15] = x1i - x3r; + x0r = y0r + y2r; + x0i = y0i + y2i; + x1r = y0r - y2r; + x1i = y0i - y2i; + x2r = y1r + y3r; + x2i = y1i + y3i; + x3r = y1r - y3r; + x3i = y1i - y3i; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[2] = x0r - x2r; + a[3] = x0i - x2i; + a[4] = x1r - x3i; + a[5] = x1i + x3r; + a[6] = x1r + x3i; + a[7] = x1i - x3r; +} + + +void cftf162(double *a, double *w) { + double wn4r, wk1r, wk1i, wk2r, wk2i, wk3r, wk3i, x0r, x0i, x1r, x1i, x2r, + x2i, y0r, y0i, y1r, y1i, y2r, y2i, y3r, y3i, y4r, y4i, y5r, y5i, y6r, + y6i, y7r, y7i, y8r, y8i, y9r, y9i, y10r, y10i, y11r, y11i, y12r, y12i, + y13r, y13i, y14r, y14i, y15r, y15i; + + wn4r = w[1]; + wk1r = w[4]; + wk1i = w[5]; + wk3r = w[6]; + wk3i = -w[7]; + wk2r = w[8]; + wk2i = w[9]; + x1r = a[0] - a[17]; + x1i = a[1] + a[16]; + x0r = a[8] - a[25]; + x0i = a[9] + a[24]; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + y0r = x1r + x2r; + y0i = x1i + x2i; + y4r = x1r - x2r; + y4i = x1i - x2i; + x1r = a[0] + a[17]; + x1i = a[1] - a[16]; + x0r = a[8] + a[25]; + x0i = a[9] - a[24]; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + y8r = x1r - x2i; + y8i = x1i + x2r; + y12r = x1r + x2i; + y12i = x1i - x2r; + x0r = a[2] - a[19]; + x0i = a[3] + a[18]; + x1r = wk1r * x0r - wk1i * x0i; + x1i = wk1r * x0i + wk1i * x0r; + x0r = a[10] - a[27]; + x0i = a[11] + a[26]; + x2r = wk3i * x0r - wk3r * x0i; + x2i = wk3i * x0i + wk3r * x0r; + y1r = x1r + x2r; + y1i = x1i + x2i; + y5r = x1r - x2r; + y5i = x1i - x2i; + x0r = a[2] + a[19]; + x0i = a[3] - a[18]; + x1r = wk3r * x0r - wk3i * x0i; + x1i = wk3r * x0i + wk3i * x0r; + x0r = a[10] + a[27]; + x0i = a[11] - a[26]; + x2r = wk1r * x0r + wk1i * x0i; + x2i = wk1r * x0i - wk1i * x0r; + y9r = x1r - x2r; + y9i = x1i - x2i; + y13r = x1r + x2r; + y13i = x1i + x2i; + x0r = a[4] - a[21]; + x0i = a[5] + a[20]; + x1r = wk2r * x0r - wk2i * x0i; + x1i = wk2r * x0i + wk2i * x0r; + x0r = a[12] - a[29]; + x0i = a[13] + a[28]; + x2r = wk2i * x0r - wk2r * x0i; + x2i = wk2i * x0i + wk2r * x0r; + y2r = x1r + x2r; + y2i = x1i + x2i; + y6r = x1r - x2r; + y6i = x1i - x2i; + x0r = a[4] + a[21]; + x0i = a[5] - a[20]; + x1r = wk2i * x0r - wk2r * x0i; + x1i = wk2i * x0i + wk2r * x0r; + x0r = a[12] + a[29]; + x0i = a[13] - a[28]; + x2r = wk2r * x0r - wk2i * x0i; + x2i = wk2r * x0i + wk2i * x0r; + y10r = x1r - x2r; + y10i = x1i - x2i; + y14r = x1r + x2r; + y14i = x1i + x2i; + x0r = a[6] - a[23]; + x0i = a[7] + a[22]; + x1r = wk3r * x0r - wk3i * x0i; + x1i = wk3r * x0i + wk3i * x0r; + x0r = a[14] - a[31]; + x0i = a[15] + a[30]; + x2r = wk1i * x0r - wk1r * x0i; + x2i = wk1i * x0i + wk1r * x0r; + y3r = x1r + x2r; + y3i = x1i + x2i; + y7r = x1r - x2r; + y7i = x1i - x2i; + x0r = a[6] + a[23]; + x0i = a[7] - a[22]; + x1r = wk1i * x0r + wk1r * x0i; + x1i = wk1i * x0i - wk1r * x0r; + x0r = a[14] + a[31]; + x0i = a[15] - a[30]; + x2r = wk3i * x0r - wk3r * x0i; + x2i = wk3i * x0i + wk3r * x0r; + y11r = x1r + x2r; + y11i = x1i + x2i; + y15r = x1r - x2r; + y15i = x1i - x2i; + x1r = y0r + y2r; + x1i = y0i + y2i; + x2r = y1r + y3r; + x2i = y1i + y3i; + a[0] = x1r + x2r; + a[1] = x1i + x2i; + a[2] = x1r - x2r; + a[3] = x1i - x2i; + x1r = y0r - y2r; + x1i = y0i - y2i; + x2r = y1r - y3r; + x2i = y1i - y3i; + a[4] = x1r - x2i; + a[5] = 
x1i + x2r; + a[6] = x1r + x2i; + a[7] = x1i - x2r; + x1r = y4r - y6i; + x1i = y4i + y6r; + x0r = y5r - y7i; + x0i = y5i + y7r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[8] = x1r + x2r; + a[9] = x1i + x2i; + a[10] = x1r - x2r; + a[11] = x1i - x2i; + x1r = y4r + y6i; + x1i = y4i - y6r; + x0r = y5r + y7i; + x0i = y5i - y7r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[12] = x1r - x2i; + a[13] = x1i + x2r; + a[14] = x1r + x2i; + a[15] = x1i - x2r; + x1r = y8r + y10r; + x1i = y8i + y10i; + x2r = y9r - y11r; + x2i = y9i - y11i; + a[16] = x1r + x2r; + a[17] = x1i + x2i; + a[18] = x1r - x2r; + a[19] = x1i - x2i; + x1r = y8r - y10r; + x1i = y8i - y10i; + x2r = y9r + y11r; + x2i = y9i + y11i; + a[20] = x1r - x2i; + a[21] = x1i + x2r; + a[22] = x1r + x2i; + a[23] = x1i - x2r; + x1r = y12r - y14i; + x1i = y12i + y14r; + x0r = y13r + y15i; + x0i = y13i - y15r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[24] = x1r + x2r; + a[25] = x1i + x2i; + a[26] = x1r - x2r; + a[27] = x1i - x2i; + x1r = y12r + y14i; + x1i = y12i - y14r; + x0r = y13r - y15i; + x0i = y13i + y15r; + x2r = wn4r * (x0r - x0i); + x2i = wn4r * (x0i + x0r); + a[28] = x1r - x2i; + a[29] = x1i + x2r; + a[30] = x1r + x2i; + a[31] = x1i - x2r; +} + + +void cftf081(double *a, double *w) { + double wn4r, x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i, y0r, y0i, y1r, y1i, + y2r, y2i, y3r, y3i, y4r, y4i, y5r, y5i, y6r, y6i, y7r, y7i; + + wn4r = w[1]; + x0r = a[0] + a[8]; + x0i = a[1] + a[9]; + x1r = a[0] - a[8]; + x1i = a[1] - a[9]; + x2r = a[4] + a[12]; + x2i = a[5] + a[13]; + x3r = a[4] - a[12]; + x3i = a[5] - a[13]; + y0r = x0r + x2r; + y0i = x0i + x2i; + y2r = x0r - x2r; + y2i = x0i - x2i; + y1r = x1r - x3i; + y1i = x1i + x3r; + y3r = x1r + x3i; + y3i = x1i - x3r; + x0r = a[2] + a[10]; + x0i = a[3] + a[11]; + x1r = a[2] - a[10]; + x1i = a[3] - a[11]; + x2r = a[6] + a[14]; + x2i = a[7] + a[15]; + x3r = a[6] - a[14]; + x3i = a[7] - a[15]; + y4r = x0r + x2r; + y4i = x0i + x2i; + y6r = x0r - x2r; + y6i = x0i - x2i; + x0r = x1r - x3i; + x0i = x1i + x3r; + x2r = x1r + x3i; + x2i = x1i - x3r; + y5r = wn4r * (x0r - x0i); + y5i = wn4r * (x0r + x0i); + y7r = wn4r * (x2r - x2i); + y7i = wn4r * (x2r + x2i); + a[8] = y1r + y5r; + a[9] = y1i + y5i; + a[10] = y1r - y5r; + a[11] = y1i - y5i; + a[12] = y3r - y7i; + a[13] = y3i + y7r; + a[14] = y3r + y7i; + a[15] = y3i - y7r; + a[0] = y0r + y4r; + a[1] = y0i + y4i; + a[2] = y0r - y4r; + a[3] = y0i - y4i; + a[4] = y2r - y6i; + a[5] = y2i + y6r; + a[6] = y2r + y6i; + a[7] = y2i - y6r; +} + + +void cftf082(double *a, double *w) { + double wn4r, wk1r, wk1i, x0r, x0i, x1r, x1i, y0r, y0i, y1r, y1i, y2r, y2i, + y3r, y3i, y4r, y4i, y5r, y5i, y6r, y6i, y7r, y7i; + + wn4r = w[1]; + wk1r = w[2]; + wk1i = w[3]; + y0r = a[0] - a[9]; + y0i = a[1] + a[8]; + y1r = a[0] + a[9]; + y1i = a[1] - a[8]; + x0r = a[4] - a[13]; + x0i = a[5] + a[12]; + y2r = wn4r * (x0r - x0i); + y2i = wn4r * (x0i + x0r); + x0r = a[4] + a[13]; + x0i = a[5] - a[12]; + y3r = wn4r * (x0r - x0i); + y3i = wn4r * (x0i + x0r); + x0r = a[2] - a[11]; + x0i = a[3] + a[10]; + y4r = wk1r * x0r - wk1i * x0i; + y4i = wk1r * x0i + wk1i * x0r; + x0r = a[2] + a[11]; + x0i = a[3] - a[10]; + y5r = wk1i * x0r - wk1r * x0i; + y5i = wk1i * x0i + wk1r * x0r; + x0r = a[6] - a[15]; + x0i = a[7] + a[14]; + y6r = wk1i * x0r - wk1r * x0i; + y6i = wk1i * x0i + wk1r * x0r; + x0r = a[6] + a[15]; + x0i = a[7] - a[14]; + y7r = wk1r * x0r - wk1i * x0i; + y7i = wk1r * x0i + wk1i * x0r; + x0r = y0r + y2r; + x0i = y0i + y2i; + x1r = y4r + y6r; + x1i 
= y4i + y6i; + a[0] = x0r + x1r; + a[1] = x0i + x1i; + a[2] = x0r - x1r; + a[3] = x0i - x1i; + x0r = y0r - y2r; + x0i = y0i - y2i; + x1r = y4r - y6r; + x1i = y4i - y6i; + a[4] = x0r - x1i; + a[5] = x0i + x1r; + a[6] = x0r + x1i; + a[7] = x0i - x1r; + x0r = y1r - y3i; + x0i = y1i + y3r; + x1r = y5r - y7r; + x1i = y5i - y7i; + a[8] = x0r + x1r; + a[9] = x0i + x1i; + a[10] = x0r - x1r; + a[11] = x0i - x1i; + x0r = y1r + y3i; + x0i = y1i - y3r; + x1r = y5r + y7r; + x1i = y5i + y7i; + a[12] = x0r - x1i; + a[13] = x0i + x1r; + a[14] = x0r + x1i; + a[15] = x0i - x1r; +} + + +void cftf040(double *a) { + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i; + + x0r = a[0] + a[4]; + x0i = a[1] + a[5]; + x1r = a[0] - a[4]; + x1i = a[1] - a[5]; + x2r = a[2] + a[6]; + x2i = a[3] + a[7]; + x3r = a[2] - a[6]; + x3i = a[3] - a[7]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[2] = x1r - x3i; + a[3] = x1i + x3r; + a[4] = x0r - x2r; + a[5] = x0i - x2i; + a[6] = x1r + x3i; + a[7] = x1i - x3r; +} + + +void cftb040(double *a) { + double x0r, x0i, x1r, x1i, x2r, x2i, x3r, x3i; + + x0r = a[0] + a[4]; + x0i = a[1] + a[5]; + x1r = a[0] - a[4]; + x1i = a[1] - a[5]; + x2r = a[2] + a[6]; + x2i = a[3] + a[7]; + x3r = a[2] - a[6]; + x3i = a[3] - a[7]; + a[0] = x0r + x2r; + a[1] = x0i + x2i; + a[2] = x1r + x3i; + a[3] = x1i - x3r; + a[4] = x0r - x2r; + a[5] = x0i - x2i; + a[6] = x1r - x3i; + a[7] = x1i + x3r; +} + + +void cftx020(double *a) { + double x0r, x0i; + + x0r = a[0] - a[2]; + x0i = a[1] - a[3]; + a[0] += a[2]; + a[1] += a[3]; + a[2] = x0r; + a[3] = x0i; +} + + +void rftfsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr, xi, yr, yi; + + m = n >> 1; + ks = 2 * nc / m; + kk = 0; + for (j = 2; j < m; j += 2) { + k = n - j; + kk += ks; + wkr = 0.5 - c[nc - kk]; + wki = c[kk]; + xr = a[j] - a[k]; + xi = a[j + 1] + a[k + 1]; + yr = wkr * xr - wki * xi; + yi = wkr * xi + wki * xr; + a[j] -= yr; + a[j + 1] -= yi; + a[k] += yr; + a[k + 1] -= yi; + } +} + + +void rftbsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr, xi, yr, yi; + + m = n >> 1; + ks = 2 * nc / m; + kk = 0; + for (j = 2; j < m; j += 2) { + k = n - j; + kk += ks; + wkr = 0.5 - c[nc - kk]; + wki = c[kk]; + xr = a[j] - a[k]; + xi = a[j + 1] + a[k + 1]; + yr = wkr * xr + wki * xi; + yi = wkr * xi - wki * xr; + a[j] -= yr; + a[j + 1] -= yi; + a[k] += yr; + a[k + 1] -= yi; + } +} + + +void dctsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr; + + m = n >> 1; + ks = nc / n; + kk = 0; + for (j = 1; j < m; j++) { + k = n - j; + kk += ks; + wkr = c[kk] - c[nc - kk]; + wki = c[kk] + c[nc - kk]; + xr = wki * a[j] - wkr * a[k]; + a[j] = wkr * a[j] + wki * a[k]; + a[k] = xr; + } + a[m] *= c[0]; +} + + +void dstsub(int n, double *a, int nc, double *c) { + int j, k, kk, ks, m; + double wkr, wki, xr; + + m = n >> 1; + ks = nc / n; + kk = 0; + for (j = 1; j < m; j++) { + k = n - j; + kk += ks; + wkr = c[kk] - c[nc - kk]; + wki = c[kk] + c[nc - kk]; + xr = wki * a[k] - wkr * a[j]; + a[k] = wkr * a[k] + wki * a[j]; + a[j] = xr; + } + a[m] *= c[0]; +} diff --git a/speechx/speechx/frontend/audio/frontend_itf.h b/runtime/engine/common/frontend/frontend_itf.h similarity index 88% rename from speechx/speechx/frontend/audio/frontend_itf.h rename to runtime/engine/common/frontend/frontend_itf.h index 7913cc7c..57186ec4 100644 --- a/speechx/speechx/frontend/audio/frontend_itf.h +++ b/runtime/engine/common/frontend/frontend_itf.h @@ -15,20 +15,20 @@ #pragma once #include 
"base/basic_types.h" -#include "kaldi/matrix/kaldi-vector.h" +#include "matrix/kaldi-vector.h" namespace ppspeech { class FrontendInterface { public: // Feed inputs: features(2D saved in 1D) or waveforms(1D). - virtual void Accept(const kaldi::VectorBase& inputs) = 0; + virtual void Accept(const std::vector& inputs) = 0; // Fetch processed data: features or waveforms. // For features(2D saved in 1D), the Matrix is squashed into Vector, // the length of output = feature_row * feature_dim. // For waveforms(1D), samples saved in vector. - virtual bool Read(kaldi::Vector* outputs) = 0; + virtual bool Read(std::vector* outputs) = 0; // Dim is the feature dim. For waveforms(1D), Dim is zero; else is specific, // e.g 80 for fbank. diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.cc b/runtime/engine/common/frontend/linear_spectrogram.cc similarity index 100% rename from speechx/speechx/frontend/audio/linear_spectrogram.cc rename to runtime/engine/common/frontend/linear_spectrogram.cc diff --git a/speechx/speechx/frontend/audio/linear_spectrogram.h b/runtime/engine/common/frontend/linear_spectrogram.h similarity index 100% rename from speechx/speechx/frontend/audio/linear_spectrogram.h rename to runtime/engine/common/frontend/linear_spectrogram.h diff --git a/runtime/engine/common/frontend/mel-computations.cc b/runtime/engine/common/frontend/mel-computations.cc new file mode 100644 index 00000000..3998af22 --- /dev/null +++ b/runtime/engine/common/frontend/mel-computations.cc @@ -0,0 +1,277 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied/modified from kaldi/src/feat/mel-computations.cc + +#include "frontend/mel-computations.h" + +#include +#include + +#include "frontend/feature-window.h" + +namespace knf { + +std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) { + os << opts.ToString(); + return os; +} + +float MelBanks::VtlnWarpFreq( + float vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. + float vtln_high_cutoff, + float low_freq, // upper+lower frequency cutoffs in mel computation + float high_freq, + float vtln_warp_factor, + float freq) { + /// This computes a VTLN warping function that is not the same as HTK's one, + /// but has similar inputs (this function has the advantage of never + /// producing + /// empty bins). + + /// This function computes a warp function F(freq), defined between low_freq + /// and high_freq inclusive, with the following properties: + /// F(low_freq) == low_freq + /// F(high_freq) == high_freq + /// The function is continuous and piecewise linear with two inflection + /// points. + /// The lower inflection point (measured in terms of the unwarped + /// frequency) is at frequency l, determined as described below. + /// The higher inflection point is at a frequency h, determined as + /// described below. 
diff --git a/runtime/engine/common/frontend/mel-computations.cc b/runtime/engine/common/frontend/mel-computations.cc
new file mode 100644
index 00000000..3998af22
--- /dev/null
+++ b/runtime/engine/common/frontend/mel-computations.cc
@@ -0,0 +1,277 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This file is copied/modified from kaldi/src/feat/mel-computations.cc
+
+#include "frontend/mel-computations.h"
+
+#include <algorithm>
+#include <sstream>
+
+#include "frontend/feature-window.h"
+
+namespace knf {
+
+std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts) {
+    os << opts.ToString();
+    return os;
+}
+
+float MelBanks::VtlnWarpFreq(
+    float vtln_low_cutoff,  // upper+lower frequency cutoffs for VTLN.
+    float vtln_high_cutoff,
+    float low_freq,  // upper+lower frequency cutoffs in mel computation
+    float high_freq,
+    float vtln_warp_factor,
+    float freq) {
+    /// This computes a VTLN warping function that is not the same as HTK's
+    /// one, but has similar inputs (this function has the advantage of never
+    /// producing empty bins).
+
+    /// This function computes a warp function F(freq), defined between
+    /// low_freq and high_freq inclusive, with the following properties:
+    ///   F(low_freq) == low_freq
+    ///   F(high_freq) == high_freq
+    /// The function is continuous and piecewise linear with two inflection
+    /// points.
+    /// The lower inflection point (measured in terms of the unwarped
+    /// frequency) is at frequency l, determined as described below.
+    /// The higher inflection point is at a frequency h, determined as
+    /// described below.
+    /// If l <= f <= h, then F(f) = f/vtln_warp_factor.
+    /// If the higher inflection point (measured in terms of the unwarped
+    /// frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
+    /// Since (by the last point) F(h) == h/vtln_warp_factor, then
+    /// max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
+    /// h = vtln_high_cutoff / max(1, 1/vtln_warp_factor)
+    ///   = vtln_high_cutoff * min(1, vtln_warp_factor).
+    /// If the lower inflection point (measured in terms of the unwarped
+    /// frequency) is at l, then min(l, F(l)) == vtln_low_cutoff.
+    /// This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
+    ///                     = vtln_low_cutoff * max(1, vtln_warp_factor)
+
+    if (freq < low_freq || freq > high_freq)
+        return freq;  // in case this gets called for out-of-range
+                      // frequencies, just return the freq.
+
+    CHECK_GT(vtln_low_cutoff, low_freq);
+    CHECK_LT(vtln_high_cutoff, high_freq);
+
+    float one = 1.0f;
+    float l = vtln_low_cutoff * std::max(one, vtln_warp_factor);
+    float h = vtln_high_cutoff * std::min(one, vtln_warp_factor);
+    float scale = 1.0f / vtln_warp_factor;
+    float Fl = scale * l;  // F(l);
+    float Fh = scale * h;  // F(h);
+    CHECK(l > low_freq && h < high_freq);
+    // slope of left part of the 3-piece linear function
+    float scale_left = (Fl - low_freq) / (l - low_freq);
+    // [slope of center part is just "scale"]
+
+    // slope of right part of the 3-piece linear function
+    float scale_right = (high_freq - Fh) / (high_freq - h);
+
+    if (freq < l) {
+        return low_freq + scale_left * (freq - low_freq);
+    } else if (freq < h) {
+        return scale * freq;
+    } else {  // freq >= h
+        return high_freq + scale_right * (freq - high_freq);
+    }
+}
+
+float MelBanks::VtlnWarpMelFreq(
+    float vtln_low_cutoff,  // upper+lower frequency cutoffs for VTLN.
+    float vtln_high_cutoff,
+    float low_freq,  // upper+lower frequency cutoffs in mel computation
+    float high_freq,
+    float vtln_warp_factor,
+    float mel_freq) {
+    return MelScale(VtlnWarpFreq(vtln_low_cutoff,
+                                 vtln_high_cutoff,
+                                 low_freq,
+                                 high_freq,
+                                 vtln_warp_factor,
+                                 InverseMelScale(mel_freq)));
+}
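Spelled out, the three-piece warp implemented by VtlnWarpFreq (writing vtln_warp_factor as $\alpha$, the low/high VTLN cutoffs as $c_{\mathrm{low}}, c_{\mathrm{high}}$, and the mel-computation cutoffs as $f_{\mathrm{low}}, f_{\mathrm{high}}$) is:

```latex
F(f) =
\begin{cases}
  f_{\mathrm{low}} + \dfrac{F(l) - f_{\mathrm{low}}}{l - f_{\mathrm{low}}}\,(f - f_{\mathrm{low}}), & f < l,\\[6pt]
  f/\alpha, & l \le f \le h,\\[6pt]
  f_{\mathrm{high}} + \dfrac{f_{\mathrm{high}} - F(h)}{f_{\mathrm{high}} - h}\,(f - f_{\mathrm{high}}), & f > h,
\end{cases}
\qquad
\begin{aligned}
l &= c_{\mathrm{low}}\,\max(1, \alpha),\\
h &= c_{\mathrm{high}}\,\min(1, \alpha),\\
F(x) &= x/\alpha.
\end{aligned}
```

Outside $[f_{\mathrm{low}}, f_{\mathrm{high}}]$ the function returns $f$ unchanged, matching the early return in the code.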
+MelBanks::MelBanks(const MelBanksOptions &opts,
+                   const FrameExtractionOptions &frame_opts,
+                   float vtln_warp_factor)
+    : htk_mode_(opts.htk_mode) {
+    int32_t num_bins = opts.num_bins;
+    if (num_bins < 3) LOG(FATAL) << "Must have at least 3 mel bins";
+
+    float sample_freq = frame_opts.samp_freq;
+    int32_t window_length_padded = frame_opts.PaddedWindowSize();
+    CHECK_EQ(window_length_padded % 2, 0);
+
+    int32_t num_fft_bins = window_length_padded / 2;
+    float nyquist = 0.5f * sample_freq;
+
+    float low_freq = opts.low_freq, high_freq;
+    if (opts.high_freq > 0.0f)
+        high_freq = opts.high_freq;
+    else
+        high_freq = nyquist + opts.high_freq;
+
+    if (low_freq < 0.0f || low_freq >= nyquist || high_freq <= 0.0f ||
+        high_freq > nyquist || high_freq <= low_freq) {
+        LOG(FATAL) << "Bad values in options: low-freq " << low_freq
+                   << " and high-freq " << high_freq << " vs. nyquist "
+                   << nyquist;
+    }
+
+    float fft_bin_width = sample_freq / window_length_padded;
+    // fft-bin width [think of it as Nyquist-freq / half-window-length]
+
+    float mel_low_freq = MelScale(low_freq);
+    float mel_high_freq = MelScale(high_freq);
+
+    debug_ = opts.debug_mel;
+
+    // divide by num_bins+1 in next line because of end-effects where the bins
+    // spread out to the sides.
+    float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
+
+    float vtln_low = opts.vtln_low, vtln_high = opts.vtln_high;
+    if (vtln_high < 0.0f) {
+        vtln_high += nyquist;
+    }
+
+    if (vtln_warp_factor != 1.0f &&
+        (vtln_low < 0.0f || vtln_low <= low_freq || vtln_low >= high_freq ||
+         vtln_high <= 0.0f || vtln_high >= high_freq ||
+         vtln_high <= vtln_low)) {
+        LOG(FATAL) << "Bad values in options: vtln-low " << vtln_low
+                   << " and vtln-high " << vtln_high << ", versus "
+                   << "low-freq " << low_freq << " and high-freq "
+                   << high_freq;
+    }
+
+    bins_.resize(num_bins);
+    center_freqs_.resize(num_bins);
+
+    for (int32_t bin = 0; bin < num_bins; ++bin) {
+        float left_mel = mel_low_freq + bin * mel_freq_delta,
+              center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
+              right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
+
+        if (vtln_warp_factor != 1.0f) {
+            left_mel = VtlnWarpMelFreq(vtln_low,
+                                       vtln_high,
+                                       low_freq,
+                                       high_freq,
+                                       vtln_warp_factor,
+                                       left_mel);
+            center_mel = VtlnWarpMelFreq(vtln_low,
+                                         vtln_high,
+                                         low_freq,
+                                         high_freq,
+                                         vtln_warp_factor,
+                                         center_mel);
+            right_mel = VtlnWarpMelFreq(vtln_low,
+                                        vtln_high,
+                                        low_freq,
+                                        high_freq,
+                                        vtln_warp_factor,
+                                        right_mel);
+        }
+        center_freqs_[bin] = InverseMelScale(center_mel);
+
+        // this_bin will be a vector of coefficients that is only
+        // nonzero where this mel bin is active.
+        std::vector<float> this_bin(num_fft_bins);
+
+        int32_t first_index = -1, last_index = -1;
+        for (int32_t i = 0; i < num_fft_bins; ++i) {
+            float freq = (fft_bin_width * i);  // Center frequency of this fft
+                                               // bin.
+            float mel = MelScale(freq);
+            if (mel > left_mel && mel < right_mel) {
+                float weight;
+                if (mel <= center_mel)
+                    weight = (mel - left_mel) / (center_mel - left_mel);
+                else
+                    weight = (right_mel - mel) / (right_mel - center_mel);
+                this_bin[i] = weight;
+                if (first_index == -1) first_index = i;
+                last_index = i;
+            }
+        }
+        CHECK(first_index != -1 && last_index >= first_index &&
+              "You may have set num_mel_bins too large.");
+
+        bins_[bin].first = first_index;
+        int32_t size = last_index + 1 - first_index;
+        bins_[bin].second.insert(bins_[bin].second.end(),
+                                 this_bin.begin() + first_index,
+                                 this_bin.begin() + first_index + size);
+
+        // Replicate a bug in HTK, for testing purposes.
+        if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0f) {
+            bins_[bin].second[0] = 0.0;
+        }
+    }  // for (int32_t bin = 0; bin < num_bins; ++bin) {
+
+    if (debug_) {
+        std::ostringstream os;
+        for (size_t i = 0; i < bins_.size(); i++) {
+            os << "bin " << i << ", offset = " << bins_[i].first
+               << ", vec = ";
+            for (auto k : bins_[i].second) os << k << ", ";
+            os << "\n";
+        }
+        LOG(INFO) << os.str();
+    }
+}
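The constructor above lays out num_bins triangular filters at equal spacing on the mel scale: bin b rises over [left_mel, center_mel] and falls over [center_mel, right_mel]. A small self-contained check of the bin-edge arithmetic (the 16 kHz / 80-bin setup is an assumed example, not a value taken from this patch):

```cpp
#include <cmath>
#include <cstdio>

// Same mel scale as MelBanks::MelScale above.
static float MelScale(float freq) {
    return 1127.0f * logf(1.0f + freq / 700.0f);
}

int main() {
    // Assumed example: low_freq = 20 Hz, high_freq = nyquist = 8000 Hz.
    const int num_bins = 80;
    const float low = MelScale(20.0f), high = MelScale(8000.0f);
    const float delta = (high - low) / (num_bins + 1);  // num_bins+1, as above
    // Bin 0 rises over [low, low+delta] and falls over [low+delta, low+2*delta].
    std::printf("bin 0 mel edges: %.2f %.2f %.2f\n",
                low, low + delta, low + 2 * delta);
    return 0;
}
```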
+
+// "power_spectrum" contains fft energies.
+void MelBanks::Compute(const float *power_spectrum,
+                       float *mel_energies_out) const {
+    int32_t num_bins = bins_.size();
+
+    for (int32_t i = 0; i < num_bins; i++) {
+        int32_t offset = bins_[i].first;
+        const auto &v = bins_[i].second;
+        float energy = 0;
+        for (int32_t k = 0; k != v.size(); ++k) {
+            energy += v[k] * power_spectrum[k + offset];
+        }
+
+        // HTK-like flooring, for testing purposes (we prefer dither)
+        if (htk_mode_ && energy < 1.0) {
+            energy = 1.0;
+        }
+
+        mel_energies_out[i] = energy;
+
+        // The following assert was added due to a problem with OpenBlas that
+        // we had at one point (it was a bug in that library). Just to detect
+        // it early.
+        CHECK_EQ(energy, energy);  // check that energy is not nan
+    }
+
+    if (debug_) {
+        fprintf(stderr, "MEL BANKS:\n");
+        for (int32_t i = 0; i < num_bins; i++)
+            fprintf(stderr, " %f", mel_energies_out[i]);
+        fprintf(stderr, "\n");
+    }
+}
+
+}  // namespace knf
diff --git a/runtime/engine/common/frontend/mel-computations.h b/runtime/engine/common/frontend/mel-computations.h
new file mode 100644
index 00000000..2f9938bc
--- /dev/null
+++ b/runtime/engine/common/frontend/mel-computations.h
@@ -0,0 +1,120 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// This file is copied/modified from kaldi/src/feat/mel-computations.h
+#ifndef KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_
+#define KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_
+
+#include <sstream>
+#include <vector>
+
+#include "frontend/feature-window.h"
+
+namespace knf {
+
+struct MelBanksOptions {
+    int32_t num_bins = 25;  // e.g. 25; number of triangular bins
+    float low_freq = 20;    // e.g. 20; lower frequency cutoff
+
+    // an upper frequency cutoff; 0 -> no cutoff, negative
+    // -> added to the Nyquist frequency to get the cutoff.
+    float high_freq = 0;
+
+    float vtln_low = 100;  // vtln lower cutoff of warping function.
+
+    // vtln upper cutoff of warping function: if negative, added
+    // to the Nyquist frequency to get the cutoff.
+    float vtln_high = -500;
+
+    bool debug_mel = false;
+    // htk_mode is a "hidden" config, it does not show up on command line.
+    // Enables more exact compatibility with HTK, for testing purposes. Affects
+    // mel-energy flooring and reproduces a bug in HTK.
+    bool htk_mode = false;
+
+    std::string ToString() const {
+        std::ostringstream os;
+        os << "num_bins: " << num_bins << "\n";
+        os << "low_freq: " << low_freq << "\n";
+        os << "high_freq: " << high_freq << "\n";
+        os << "vtln_low: " << vtln_low << "\n";
+        os << "vtln_high: " << vtln_high << "\n";
+        os << "debug_mel: " << debug_mel << "\n";
+        os << "htk_mode: " << htk_mode << "\n";
+        return os.str();
+    }
+};
+
+std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts);
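With the options struct in place, a typical call sequence for the MelBanks class declared just below is sketched here. This is hedged usage, not code from the patch: FrameExtractionOptions comes from frontend/feature-window.h and is assumed to be default-constructible with sensible defaults.

```cpp
// Hedged usage sketch, assuming the headers added by this patch.
// #include "frontend/mel-computations.h"
#include <vector>

void FbankEnergiesExample() {
    knf::MelBanksOptions mel_opts;
    mel_opts.num_bins = 80;                  // e.g. an fbank80 front end
    knf::FrameExtractionOptions frame_opts;  // assumed defaults
    knf::MelBanks banks(mel_opts, frame_opts, 1.0f /* vtln_warp_factor */);

    std::vector<float> power(257);           // n/2 + 1 FFT energies for n = 512
    std::vector<float> mel(banks.NumBins());
    banks.Compute(power.data(), mel.data()); // fold spectrum into mel bins
}
```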
+
+class MelBanks {
+  public:
+    static inline float InverseMelScale(float mel_freq) {
+        return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
+    }
+
+    static inline float MelScale(float freq) {
+        return 1127.0f * logf(1.0f + freq / 700.0f);
+    }
+
+    static float VtlnWarpFreq(
+        float vtln_low_cutoff,
+        float vtln_high_cutoff,  // discontinuities in warp func
+        float low_freq,
+        float high_freq,  // upper+lower frequency cutoffs in
+                          // the mel computation
+        float vtln_warp_factor,
+        float freq);
+
+    static float VtlnWarpMelFreq(float vtln_low_cutoff,
+                                 float vtln_high_cutoff,
+                                 float low_freq,
+                                 float high_freq,
+                                 float vtln_warp_factor,
+                                 float mel_freq);
+
+    // TODO(fangjun): Remove vtln_warp_factor
+    MelBanks(const MelBanksOptions &opts,
+             const FrameExtractionOptions &frame_opts,
+             float vtln_warp_factor);
+
+    /// Compute Mel energies (note: not log energies).
+    /// At input, "fft_energies" contains the FFT energies (not log).
+    ///
+    /// @param fft_energies 1-D array of size num_fft_bins/2+1
+    /// @param mel_energies_out 1-D array of size num_mel_bins
+    void Compute(const float *fft_energies, float *mel_energies_out) const;
+
+    int32_t NumBins() const { return bins_.size(); }
+
+  private:
+    // center frequencies of bins, numbered from 0 ... num_bins-1.
+    // Needed by GetCenterFreqs().
+    std::vector<float> center_freqs_;
+
+    // the "bins_" vector is a vector, one for each bin, of a pair:
+    // (the first nonzero fft-bin), (the vector of weights).
+    std::vector<std::pair<int32_t, std::vector<float>>> bins_;
+
+    // TODO(fangjun): Remove debug_ and htk_mode_
+    bool debug_;
+    bool htk_mode_;
+};
+
+}  // namespace knf
+
+#endif  // KALDI_NATIVE_FBANK_CSRC_MEL_COMPUTATIONS_H_
diff --git a/speechx/speechx/frontend/audio/normalizer.h b/runtime/engine/common/frontend/normalizer.h
similarity index 90%
rename from speechx/speechx/frontend/audio/normalizer.h
rename to runtime/engine/common/frontend/normalizer.h
index dcf721dd..5a6ca573 100644
--- a/speechx/speechx/frontend/audio/normalizer.h
+++ b/runtime/engine/common/frontend/normalizer.h
@@ -14,5 +14,4 @@
 #pragma once
 
-#include "frontend/audio/cmvn.h"
-#include "frontend/audio/db_norm.h"
\ No newline at end of file
+#include "frontend/cmvn.h"
\ No newline at end of file
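After the rename, normalizer.h forwards only to CMVN (db_norm is dropped). As a reminder of what that stage computes, here is a minimal sketch of per-dimension mean-variance normalization; the function name and statistics layout are illustrative assumptions, not the cmvn.h API:

```cpp
#include <cmath>
#include <vector>

// Illustrative CMVN: normalize one feature frame to zero mean and unit
// variance per dimension, given accumulated stats (sum, squared sum, count).
void ApplyCmvn(const std::vector<double>& sum,
               const std::vector<double>& sqsum,
               double count,
               std::vector<float>* frame) {
    for (size_t d = 0; d < frame->size(); ++d) {
        const double mean = sum[d] / count;
        const double var = sqsum[d] / count - mean * mean;
        (*frame)[d] = static_cast<float>(
            ((*frame)[d] - mean) / std::sqrt(var + 1e-20));  // guard var ~ 0
    }
}
```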
diff --git a/runtime/engine/common/frontend/rfft.cc b/runtime/engine/common/frontend/rfft.cc
new file mode 100644
index 00000000..9ce6a172
--- /dev/null
+++ b/runtime/engine/common/frontend/rfft.cc
@@ -0,0 +1,67 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/rfft.h"
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "base/log.h"
+
+// see fftsg.c
+#ifdef __cplusplus
+extern "C" void rdft(int n, int isgn, double *a, int *ip, double *w);
+#else
+void rdft(int n, int isgn, double *a, int *ip, double *w);
+#endif
+
+namespace knf {
+class Rfft::RfftImpl {
+  public:
+    explicit RfftImpl(int32_t n) : n_(n), ip_(2 + std::sqrt(n / 2)), w_(n / 2) {
+        CHECK_EQ(n & (n - 1), 0);
+    }
+
+    void Compute(float *in_out) {
+        std::vector<double> d(in_out, in_out + n_);
+
+        Compute(d.data());
+
+        std::copy(d.begin(), d.end(), in_out);
+    }
+
+    void Compute(double *in_out) {
+        // 1 means forward fft
+        rdft(n_, 1, in_out, ip_.data(), w_.data());
+    }
+
+  private:
+    int32_t n_;
+    std::vector<int32_t> ip_;
+    std::vector<double> w_;
+};
+
+Rfft::Rfft(int32_t n) : impl_(std::make_unique<RfftImpl>(n)) {}
+
+Rfft::~Rfft() = default;
+
+void Rfft::Compute(float *in_out) { impl_->Compute(in_out); }
+void Rfft::Compute(double *in_out) { impl_->Compute(in_out); }
+
+}  // namespace knf
diff --git a/runtime/engine/common/frontend/rfft.h b/runtime/engine/common/frontend/rfft.h
new file mode 100644
index 00000000..52da2626
--- /dev/null
+++ b/runtime/engine/common/frontend/rfft.h
@@ -0,0 +1,56 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef KALDI_NATIVE_FBANK_CSRC_RFFT_H_
+#define KALDI_NATIVE_FBANK_CSRC_RFFT_H_
+
+#include <memory>
+
+namespace knf {
+
+// n-point Real discrete Fourier transform
+// where n is a power of 2. n >= 2
+//
+// R[k] = sum_j=0^n-1 in[j]*cos(2*pi*j*k/n), 0 <= k <= n/2
+// I[k] = sum_j=0^n-1 in[j]*sin(2*pi*j*k/n), 0 < k < n/2
+class Rfft {
+  public:
+    // @param n Number of FFT bins. It must be a power of 2.
+    explicit Rfft(int32_t n);
+    ~Rfft();
+
+    // In-place transform of a 1-D array of size n.
+    void Compute(float *in_out);
+    void Compute(double *in_out);
+
+  private:
+    class RfftImpl;
+    std::unique_ptr<RfftImpl> impl_;
+};
+
+}  // namespace knf
+
+#endif  // KALDI_NATIVE_FBANK_CSRC_RFFT_H_
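Driving the Rfft wrapper is a two-liner; a minimal sketch. The packed output layout in the comment is an assumption based on the Ooura fftsg convention that rfft.cc delegates to ("see fftsg.c"), not something this patch states:

    #include <cstdint>
    #include <vector>

    #include "frontend/rfft.h"

    void SketchRfft() {
        const int32_t n = 512;  // must be a power of 2 (enforced by CHECK_EQ)
        knf::Rfft rfft(n);

        std::vector<float> buf(n, 0.0f);  // fill with n windowed samples
        rfft.Compute(buf.data());         // in-place real FFT

        // Assumed Ooura packing: buf[0] = R[0], buf[1] = R[n/2], and for
        // 0 < k < n/2 the pair buf[2k], buf[2k+1] holds the k-th bin.
        float dc = buf[0];
        (void)dc;
    }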
+ +#include "frontend/wave-reader.h" + +#include +#include +#include +#include +#include + +#include "base/kaldi-error.h" +#include "base/kaldi-utils.h" + +namespace kaldi { + +// A utility class for reading wave header. +struct WaveHeaderReadGofer { + std::istream &is; + bool swap; + char tag[5]; + + WaveHeaderReadGofer(std::istream &is) : is(is), swap(false) { + memset(tag, '\0', sizeof tag); + } + + void Expect4ByteTag(const char *expected) { + is.read(tag, 4); + if (is.fail()) + KALDI_ERR << "WaveData: expected " << expected + << ", failed to read anything"; + if (strcmp(tag, expected)) + KALDI_ERR << "WaveData: expected " << expected << ", got " << tag; + } + + void Read4ByteTag() { + is.read(tag, 4); + if (is.fail()) + KALDI_ERR << "WaveData: expected 4-byte chunk-name, got read error"; + } + + uint32 ReadUint32() { + union { + char result[4]; + uint32 ans; + } u; + is.read(u.result, 4); + if (swap) KALDI_SWAP4(u.result); + if (is.fail()) + KALDI_ERR << "WaveData: unexpected end of file or read error"; + return u.ans; + } + + uint16 ReadUint16() { + union { + char result[2]; + int16 ans; + } u; + is.read(u.result, 2); + if (swap) KALDI_SWAP2(u.result); + if (is.fail()) + KALDI_ERR << "WaveData: unexpected end of file or read error"; + return u.ans; + } +}; + +static void WriteUint32(std::ostream &os, int32 i) { + union { + char buf[4]; + int i; + } u; + u.i = i; +#ifdef __BIG_ENDIAN__ + KALDI_SWAP4(u.buf); +#endif + os.write(u.buf, 4); + if (os.fail()) KALDI_ERR << "WaveData: error writing to stream."; +} + +static void WriteUint16(std::ostream &os, int16 i) { + union { + char buf[2]; + int16 i; + } u; + u.i = i; +#ifdef __BIG_ENDIAN__ + KALDI_SWAP2(u.buf); +#endif + os.write(u.buf, 2); + if (os.fail()) KALDI_ERR << "WaveData: error writing to stream."; +} + +void WaveInfo::Read(std::istream &is) { + WaveHeaderReadGofer reader(is); + reader.Read4ByteTag(); + if (strcmp(reader.tag, "RIFF") == 0) + reverse_bytes_ = false; + else if (strcmp(reader.tag, "RIFX") == 0) + reverse_bytes_ = true; + else + KALDI_ERR << "WaveData: expected RIFF or RIFX, got " << reader.tag; + +#ifdef __BIG_ENDIAN__ + reverse_bytes_ = !reverse_bytes_; +#endif + reader.swap = reverse_bytes_; + + uint32 riff_chunk_size = reader.ReadUint32(); + reader.Expect4ByteTag("WAVE"); + + uint32 riff_chunk_read = 0; + riff_chunk_read += 4; // WAVE included in riff_chunk_size. + + // Possibly skip any RIFF tags between 'WAVE' and 'fmt '. + // Apple devices produce a filler tag 'JUNK' for memory alignment. 
+void WaveInfo::Read(std::istream &is) {
+    WaveHeaderReadGofer reader(is);
+    reader.Read4ByteTag();
+    if (strcmp(reader.tag, "RIFF") == 0)
+        reverse_bytes_ = false;
+    else if (strcmp(reader.tag, "RIFX") == 0)
+        reverse_bytes_ = true;
+    else
+        KALDI_ERR << "WaveData: expected RIFF or RIFX, got " << reader.tag;
+
+#ifdef __BIG_ENDIAN__
+    reverse_bytes_ = !reverse_bytes_;
+#endif
+    reader.swap = reverse_bytes_;
+
+    uint32 riff_chunk_size = reader.ReadUint32();
+    reader.Expect4ByteTag("WAVE");
+
+    uint32 riff_chunk_read = 0;
+    riff_chunk_read += 4;  // WAVE included in riff_chunk_size.
+
+    // Possibly skip any RIFF tags between 'WAVE' and 'fmt '.
+    // Apple devices produce a filler tag 'JUNK' for memory alignment.
+    reader.Read4ByteTag();
+    riff_chunk_read += 4;
+    while (strcmp(reader.tag, "fmt ") != 0) {
+        uint32 filler_size = reader.ReadUint32();
+        riff_chunk_read += 4;
+        for (uint32 i = 0; i < filler_size; i++) {
+            is.get();  // read 1 byte,
+        }
+        riff_chunk_read += filler_size;
+        // get next RIFF tag,
+        reader.Read4ByteTag();
+        riff_chunk_read += 4;
+    }
+
+    KALDI_ASSERT(strcmp(reader.tag, "fmt ") == 0);
+    uint32 subchunk1_size = reader.ReadUint32();
+    uint16 audio_format = reader.ReadUint16();
+    num_channels_ = reader.ReadUint16();
+    uint32 sample_rate = reader.ReadUint32(), byte_rate = reader.ReadUint32(),
+           block_align = reader.ReadUint16(),
+           bits_per_sample = reader.ReadUint16();
+    samp_freq_ = static_cast<BaseFloat>(sample_rate);
+
+    uint32 fmt_chunk_read = 16;
+    if (audio_format == 1) {
+        if (subchunk1_size < 16) {
+            KALDI_ERR << "WaveData: expect PCM format data to have fmt chunk "
+                      << "of at least size 16.";
+        }
+    } else if (audio_format == 0xFFFE) {  // WAVE_FORMAT_EXTENSIBLE
+        uint16 extra_size = reader.ReadUint16();
+        if (subchunk1_size < 40 || extra_size < 22) {
+            KALDI_ERR
+                << "WaveData: malformed WAVE_FORMAT_EXTENSIBLE format data.";
+        }
+        reader.ReadUint16();  // Unused for PCM.
+        reader.ReadUint32();  // Channel map: we do not care.
+        uint32 guid1 = reader.ReadUint32(), guid2 = reader.ReadUint32(),
+               guid3 = reader.ReadUint32(), guid4 = reader.ReadUint32();
+        fmt_chunk_read = 40;
+
+        // Support only KSDATAFORMAT_SUBTYPE_PCM for now. Interesting formats:
+        // ("00000001-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_PCM)
+        // ("00000003-0000-0010-8000-00aa00389b71",
+        //  KSDATAFORMAT_SUBTYPE_IEEE_FLOAT)
+        // ("00000006-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_ALAW)
+        // ("00000007-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_MULAW)
+        if (guid1 != 0x00000001 || guid2 != 0x00100000 ||
+            guid3 != 0xAA000080 || guid4 != 0x719B3800) {
+            KALDI_ERR << "WaveData: unsupported WAVE_FORMAT_EXTENSIBLE format.";
+        }
+    } else {
+        KALDI_ERR << "WaveData: can read only PCM data, format id in file is: "
+                  << audio_format;
+    }
+
+    for (uint32 i = fmt_chunk_read; i < subchunk1_size; ++i)
+        is.get();  // use up extra data.
+
+    if (num_channels_ == 0) KALDI_ERR << "WaveData: no channels present";
+    if (bits_per_sample != 16)
+        KALDI_ERR << "WaveData: unsupported bits_per_sample = "
+                  << bits_per_sample;
+    if (byte_rate != sample_rate * bits_per_sample / 8 * num_channels_)
+        KALDI_ERR << "Unexpected byte rate " << byte_rate << " vs. "
+                  << sample_rate << " * " << (bits_per_sample / 8) << " * "
+                  << num_channels_;
+    if (block_align != num_channels_ * bits_per_sample / 8)
+        KALDI_ERR << "Unexpected block_align: " << block_align << " vs. "
+                  << num_channels_ << " * " << (bits_per_sample / 8);
+
+    riff_chunk_read += 4 + subchunk1_size;
+    // size of what we just read: 4 for subchunk1_size + subchunk1_size itself.
+
+    // We support an optional "fact" chunk (which is useless but which
+    // we encountered), and then a single "data" chunk.
+
+    reader.Read4ByteTag();
+    riff_chunk_read += 4;
+
+    // Skip any subchunks between "fmt" and "data". Usually there will
+    // be a single "fact" subchunk, but on Windows there can also be a
+    // "list" subchunk.
+    while (strcmp(reader.tag, "data") != 0) {
+        // We will just ignore the data in these chunks.
+        uint32 chunk_sz = reader.ReadUint32();
+        if (chunk_sz != 4 && strcmp(reader.tag, "fact") == 0)
+            KALDI_WARN << "Expected fact chunk to be 4 bytes long.";
+        for (uint32 i = 0; i < chunk_sz; i++) is.get();
+        riff_chunk_read +=
+            4 + chunk_sz;  // for chunk_sz (4) + chunk contents (chunk-sz)
+
+        // Now read the next chunk name.
+        reader.Read4ByteTag();
+        riff_chunk_read += 4;
+    }
+
+    KALDI_ASSERT(strcmp(reader.tag, "data") == 0);
+    uint32 data_chunk_size = reader.ReadUint32();
+    riff_chunk_read += 4;
+
+    // Figure out if the file is going to be read to the end. Values as
+    // observed in the wild:
+    bool is_stream_mode =
+        riff_chunk_size == 0 || riff_chunk_size == 0xFFFFFFFF ||
+        data_chunk_size == 0 || data_chunk_size == 0xFFFFFFFF ||
+        data_chunk_size == 0x7FFFF000;  // This value is used by SoX.
+
+    if (is_stream_mode)
+        KALDI_VLOG(1) << "Read in RIFF chunk size: " << riff_chunk_size
+                      << ", data chunk size: " << data_chunk_size
+                      << ". Assume 'stream mode' (reading data to EOF).";
+
+    if (!is_stream_mode &&
+        std::abs(static_cast<int64>(riff_chunk_read) +
+                 static_cast<int64>(data_chunk_size) -
+                 static_cast<int64>(riff_chunk_size)) > 1) {
+        // We allow the size to be off by one without warning, because there is
+        // a weirdness in the format of RIFF files that means that the input
+        // may sometimes be padded with 1 unused byte to make the total size
+        // even.
+        KALDI_WARN << "Expected " << riff_chunk_size
+                   << " bytes in RIFF chunk, but "
+                   << "after first data block there will be " << riff_chunk_read
+                   << " + " << data_chunk_size << " bytes "
+                   << "(we do not support reading multiple data chunks).";
+    }
+
+    if (is_stream_mode)
+        samp_count_ = -1;
+    else
+        samp_count_ = data_chunk_size / block_align;
+}
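// The fmt-chunk validation above encodes the usual PCM relations. Worked
// numbers for a 16 kHz, mono, 16-bit file (illustrative, not from the patch):

#include <cassert>
#include <cstdint>

void SketchFmtRelations() {
    uint32_t sample_rate = 16000, bits_per_sample = 16, num_channels = 1;
    uint32_t byte_rate = sample_rate * bits_per_sample / 8 * num_channels;
    uint32_t block_align = num_channels * bits_per_sample / 8;
    assert(byte_rate == 32000 && block_align == 2);
    // samp_count = data_chunk_size / block_align,
    // e.g. a 64000-byte data chunk holds 32000 samples (2 seconds).
    (void)byte_rate; (void)block_align;
}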
+void WaveData::Read(std::istream &is) {
+    const uint32 kBlockSize = 1024 * 1024;
+
+    WaveInfo header;
+    header.Read(is);
+
+    data_.Resize(0, 0);  // clear the data.
+    samp_freq_ = header.SampFreq();
+
+    std::vector<char> buffer;
+    uint32 bytes_to_go = header.IsStreamed() ? kBlockSize : header.DataBytes();
+
+    // Once in a while header.DataBytes() will report an insane value;
+    // read the file to the end
+    while (is && bytes_to_go > 0) {
+        uint32 block_bytes = std::min(bytes_to_go, kBlockSize);
+        uint32 offset = buffer.size();
+        buffer.resize(offset + block_bytes);
+        is.read(&buffer[offset], block_bytes);
+        uint32 bytes_read = is.gcount();
+        buffer.resize(offset + bytes_read);
+        if (!header.IsStreamed()) bytes_to_go -= bytes_read;
+    }
+
+    if (is.bad()) KALDI_ERR << "WaveData: file read error";
+
+    if (buffer.size() == 0) KALDI_ERR << "WaveData: empty file (no data)";
+
+    if (!header.IsStreamed() && buffer.size() < header.DataBytes()) {
+        KALDI_WARN << "Expected " << header.DataBytes()
+                   << " bytes of wave data, "
+                   << "but read only " << buffer.size() << " bytes. "
+                   << "Truncated file?";
+    }
+
+    uint16 *data_ptr = reinterpret_cast<uint16 *>(&buffer[0]);
+
+    // The matrix is arranged row per channel, column per sample.
+    data_.Resize(header.NumChannels(), buffer.size() / header.BlockAlign());
+    for (uint32 i = 0; i < data_.NumCols(); ++i) {
+        for (uint32 j = 0; j < data_.NumRows(); ++j) {
+            int16 k = *data_ptr++;
+            if (header.ReverseBytes()) KALDI_SWAP2(k);
+            data_(j, i) = k;
+        }
+    }
+}
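// Typical call site for the reader above; a minimal sketch (the file name is
// hypothetical):

#include <fstream>

void SketchReadWav() {
    std::ifstream is("input.wav", std::ios::binary);  // binary mode required
    kaldi::WaveData wave;
    wave.Read(is);  // throws via KALDI_ERR on malformed input

    // One row per channel; values stay in the raw int16 range, not [-1, 1].
    const kaldi::Matrix<kaldi::BaseFloat> &pcm = wave.Data();
    kaldi::BaseFloat first_sample = pcm(0, 0);
    (void)first_sample;
}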
+
+
+// Write 16-bit PCM.
+
+// note: the WAVE chunk contains 2 subchunks.
+//
+// subchunk2size = data.NumRows() * data.NumCols() * 2.
+
+
+void WaveData::Write(std::ostream &os) const {
+    os << "RIFF";
+    if (data_.NumRows() == 0)
+        KALDI_ERR << "Error: attempting to write empty WAVE file";
+
+    int32 num_chan = data_.NumRows(), num_samp = data_.NumCols(),
+          bytes_per_samp = 2;
+
+    int32 subchunk2size = (num_chan * num_samp * bytes_per_samp);
+    int32 chunk_size = 36 + subchunk2size;
+    WriteUint32(os, chunk_size);
+    os << "WAVE";
+    os << "fmt ";
+    WriteUint32(os, 16);
+    WriteUint16(os, 1);
+    WriteUint16(os, num_chan);
+    KALDI_ASSERT(samp_freq_ > 0);
+    WriteUint32(os, static_cast<uint32>(samp_freq_));
+    WriteUint32(os,
+                static_cast<uint32>(samp_freq_) * num_chan * bytes_per_samp);
+    WriteUint16(os, num_chan * bytes_per_samp);
+    WriteUint16(os, 8 * bytes_per_samp);
+    os << "data";
+    WriteUint32(os, subchunk2size);
+
+    const BaseFloat *data_ptr = data_.Data();
+    int32 stride = data_.Stride();
+
+    int num_clipped = 0;
+    for (int32 i = 0; i < num_samp; i++) {
+        for (int32 j = 0; j < num_chan; j++) {
+            int32 elem = static_cast<int32>(trunc(data_ptr[j * stride + i]));
+            int16 elem_16 = static_cast<int16>(elem);
+            if (elem < std::numeric_limits<int16>::min()) {
+                elem_16 = std::numeric_limits<int16>::min();
+                ++num_clipped;
+            } else if (elem > std::numeric_limits<int16>::max()) {
+                elem_16 = std::numeric_limits<int16>::max();
+                ++num_clipped;
+            }
+#ifdef __BIG_ENDIAN__
+            KALDI_SWAP2(elem_16);
+#endif
+            os.write(reinterpret_cast<const char *>(&elem_16), 2);
+        }
+    }
+    if (os.fail()) KALDI_ERR << "Error writing wave data to stream.";
+    if (num_clipped > 0)
+        KALDI_WARN << "WARNING: clipped " << num_clipped
+                   << " samples out of total " << num_chan * num_samp
+                   << ". Reduce volume?";
+}
+
+
+}  // end namespace kaldi
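Write() above emits the canonical 44-byte PCM header, and its two size fields follow directly from the matrix shape. A worked check of the arithmetic (numbers are illustrative):

    #include <cassert>
    #include <cstdint>

    int main() {
        int32_t num_chan = 1, num_samp = 16000, bytes_per_samp = 2;
        int32_t subchunk2size = num_chan * num_samp * bytes_per_samp;  // 32000
        int32_t chunk_size = 36 + subchunk2size;                       // 32036
        // "RIFF" plus the chunk_size field take 8 bytes, so the file totals
        // 8 + chunk_size = 44-byte header + subchunk2size bytes of samples.
        assert(8 + chunk_size == 44 + subchunk2size);
        return 0;
    }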
diff --git a/runtime/engine/common/frontend/wave-reader.h b/runtime/engine/common/frontend/wave-reader.h
new file mode 100644
index 00000000..6cd471b8
--- /dev/null
+++ b/runtime/engine/common/frontend/wave-reader.h
@@ -0,0 +1,248 @@
+// feat/wave-reader.h
+
+// Copyright 2009-2011  Karel Vesely;  Microsoft Corporation
+//                2013  Florent Masson
+//                2013  Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+/*
+// THE WAVE FORMAT IS SPECIFIED IN:
+// https://ccrma.stanford.edu/courses/422/projects/WaveFormat/
+//
+//
+//
+//  RIFF
+//  |
+//  WAVE
+//  |    \    \    \
+//  fmt_ data ... data
+//
+//
+//  Riff is a general container, which usually contains one WAVE chunk;
+//  each WAVE chunk has a header sub-chunk 'fmt_'
+//  and one or more data sub-chunks 'data'.
+//
+// [Note from Dan: to say that the wave format was ever "specified" anywhere is
+//  not quite right. The guy who invented the wave format attempted to create
+//  a formal specification but it did not completely make sense. And there
+//  doesn't seem to be a consensus on what makes a valid wave file,
+//  particularly where the accuracy of header information is concerned.]
+*/
+
+
+#ifndef KALDI_FEAT_WAVE_READER_H_
+#define KALDI_FEAT_WAVE_READER_H_
+
+#include <cstdint>
+
+#include "base/kaldi-types.h"
+#include "matrix/kaldi-matrix.h"
+#include "matrix/kaldi-vector.h"
+
+
+namespace kaldi {
+
+/// For historical reasons, we scale waveforms to the range
+/// (2^15-1)*[-1, 1], not the usual default DSP range [-1, 1].
+const BaseFloat kWaveSampleMax = 32768.0;
+
+/// This class reads and holds wave file header information.
+class WaveInfo {
+  public:
+    WaveInfo()
+        : samp_freq_(0), samp_count_(0), num_channels_(0), reverse_bytes_(0) {}
+
+    /// Is stream size unknown? Duration and SampleCount not valid if true.
+    bool IsStreamed() const { return samp_count_ < 0; }
+
+    /// Sample frequency, Hz.
+    BaseFloat SampFreq() const { return samp_freq_; }
+
+    /// Number of samples in stream. Invalid if IsStreamed() is true.
+    uint32 SampleCount() const { return samp_count_; }
+
+    /// Approximate duration, seconds. Invalid if IsStreamed() is true.
+    BaseFloat Duration() const { return samp_count_ / samp_freq_; }
+
+    /// Number of channels, 1 to 16.
+    int32 NumChannels() const { return num_channels_; }
+
+    /// Bytes per sample.
+    size_t BlockAlign() const { return 2 * num_channels_; }
+
+    /// Wave data bytes. Invalid if IsStreamed() is true.
+    size_t DataBytes() const { return samp_count_ * BlockAlign(); }
+
+    /// Is data file byte order different from machine byte order?
+    bool ReverseBytes() const { return reverse_bytes_; }
+
+    /// 'is' should be opened in binary mode. Read() will throw on error.
+    /// On success 'is' will be positioned at the beginning of wave data.
+    void Read(std::istream &is);
+
+  private:
+    BaseFloat samp_freq_;
+    int32 samp_count_;  // 0 if empty, -1 if undefined length.
+    uint8 num_channels_;
+    bool reverse_bytes_;  // File endianness differs from host.
+};
+
+/// This class's purpose is to read in Wave files.
+class WaveData {
+  public:
+    WaveData(BaseFloat samp_freq, const MatrixBase<BaseFloat> &data)
+        : data_(data), samp_freq_(samp_freq) {}
+
+    WaveData() : samp_freq_(0.0) {}
+
+    /// Read() will throw on error. It's valid to call Read() more than once--
+    /// in this case it will destroy what was there before.
+    /// "is" should be opened in binary mode.
+    void Read(std::istream &is);
+
+    /// Write() will throw on error. os should be opened in binary mode.
+    void Write(std::ostream &os) const;
+
+    // This function returns the wave data-- it's in a matrix
+    // because there may be multiple channels. In the normal case
+    // there's just one channel so Data() will have one row.
+    const Matrix<BaseFloat> &Data() const { return data_; }
+
+    BaseFloat SampFreq() const { return samp_freq_; }
+
+    // Returns the duration in seconds
+    BaseFloat Duration() const { return data_.NumCols() / samp_freq_; }
+
+    void CopyFrom(const WaveData &other) {
+        samp_freq_ = other.samp_freq_;
+        data_.CopyFromMat(other.data_);
+    }
+
+    void Clear() {
+        data_.Resize(0, 0);
+        samp_freq_ = 0.0;
+    }
+
+    void Swap(WaveData *other) {
+        data_.Swap(&(other->data_));
+        std::swap(samp_freq_, other->samp_freq_);
+    }
+
+  private:
+    static const uint32 kBlockSize = 1024 * 1024;  // Use 1M bytes.
+    Matrix<BaseFloat> data_;
+    BaseFloat samp_freq_;
+};
+
+
+// Holder class for .wav files that enables us to read (but not write) .wav
+// files. cf. util/kaldi-holder.h; we don't use the KaldiObjectHolder template
+// because we don't want to check for the \0B binary header. We could have
+// faked it by pretending to read in the wave data in text mode after failing
+// to find the \0B header, but that would have been a little ugly.
+class WaveHolder { + public: + typedef WaveData T; + + static bool Write(std::ostream &os, bool binary, const T &t) { + // We don't write the binary-mode header here [always binary]. + if (!binary) + KALDI_ERR << "Wave data can only be written in binary mode."; + try { + t.Write(os); // throws exception on failure. + return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveHolder object (writing). " + << e.what(); + return false; // write failure. + } + } + void Copy(const T &t) { t_.CopyFrom(t); } + + static bool IsReadInBinary() { return true; } + + void Clear() { t_.Clear(); } + + T &Value() { return t_; } + + WaveHolder &operator=(const WaveHolder &other) { + t_.CopyFrom(other.t_); + return *this; + } + WaveHolder(const WaveHolder &other) : t_(other.t_) {} + + WaveHolder() {} + + bool Read(std::istream &is) { + // We don't look for the binary-mode header here [always binary] + try { + t_.Read(is); // Throws exception on failure. + return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveHolder::Read(). " + << e.what(); + return false; + } + } + + void Swap(WaveHolder *other) { t_.Swap(&(other->t_)); } + + bool ExtractRange(const WaveHolder &other, const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + T t_; +}; + +// This is like WaveHolder but when you just want the metadata- +// it leaves the actual data undefined, it doesn't read it. +class WaveInfoHolder { + public: + typedef WaveInfo T; + + void Clear() { info_ = WaveInfo(); } + void Swap(WaveInfoHolder *other) { std::swap(info_, other->info_); } + T &Value() { return info_; } + static bool IsReadInBinary() { return true; } + + bool Read(std::istream &is) { + try { + info_.Read(is); // Throws exception on failure. + return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveInfoHolder::Read(). 
" + << e.what(); + return false; + } + } + + bool ExtractRange(const WaveInfoHolder &other, const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + WaveInfo info_; +}; + + +} // namespace kaldi + +#endif // KALDI_FEAT_WAVE_READER_H_ diff --git a/runtime/engine/common/matrix/CMakeLists.txt b/runtime/engine/common/matrix/CMakeLists.txt new file mode 100644 index 00000000..a4b34d54 --- /dev/null +++ b/runtime/engine/common/matrix/CMakeLists.txt @@ -0,0 +1,7 @@ + +add_library(kaldi-matrix +kaldi-matrix.cc +kaldi-vector.cc +) + +target_link_libraries(kaldi-matrix kaldi-base) diff --git a/speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h b/runtime/engine/common/matrix/kaldi-matrix-inl.h similarity index 65% rename from speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h rename to runtime/engine/common/matrix/kaldi-matrix-inl.h index c2ff0079..ed18859d 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h +++ b/runtime/engine/common/matrix/kaldi-matrix-inl.h @@ -25,39 +25,41 @@ namespace kaldi { /// Empty constructor -template -Matrix::Matrix(): MatrixBase(NULL, 0, 0, 0) { } - +template +Matrix::Matrix() : MatrixBase(NULL, 0, 0, 0) {} +/* template<> template<> -void MatrixBase::AddVecVec(const float alpha, const VectorBase &ra, const VectorBase &rb); +void MatrixBase::AddVecVec(const float alpha, const VectorBase +&ra, const VectorBase &rb); template<> template<> -void MatrixBase::AddVecVec(const double alpha, const VectorBase &ra, const VectorBase &rb); - -template -inline std::ostream & operator << (std::ostream & os, const MatrixBase & M) { - M.Write(os, false); - return os; +void MatrixBase::AddVecVec(const double alpha, const VectorBase +&ra, const VectorBase &rb); +*/ + +template +inline std::ostream& operator<<(std::ostream& os, const MatrixBase& M) { + M.Write(os, false); + return os; } -template -inline std::istream & operator >> (std::istream & is, Matrix & M) { - M.Read(is, false); - return is; +template +inline std::istream& operator>>(std::istream& is, Matrix& M) { + M.Read(is, false); + return is; } -template -inline std::istream & operator >> (std::istream & is, MatrixBase & M) { - M.Read(is, false); - return is; +template +inline std::istream& operator>>(std::istream& is, MatrixBase& M) { + M.Read(is, false); + return is; } -}// namespace kaldi +} // namespace kaldi #endif // KALDI_MATRIX_KALDI_MATRIX_INL_H_ - diff --git a/speechx/speechx/kaldi/matrix/kaldi-matrix.cc b/runtime/engine/common/matrix/kaldi-matrix.cc similarity index 75% rename from speechx/speechx/kaldi/matrix/kaldi-matrix.cc rename to runtime/engine/common/matrix/kaldi-matrix.cc index faf23cdf..6f65fb0a 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-matrix.cc +++ b/runtime/engine/common/matrix/kaldi-matrix.cc @@ -23,17 +23,9 @@ // limitations under the License. 
#include "matrix/kaldi-matrix.h" -#include "matrix/sp-matrix.h" -#include "matrix/jama-svd.h" -#include "matrix/jama-eig.h" -#include "matrix/compressed-matrix.h" -#include "matrix/sparse-matrix.h" - -static_assert(int(kaldi::kNoTrans) == int(CblasNoTrans) && int(kaldi::kTrans) == int(CblasTrans), - "kaldi::kNoTrans and kaldi::kTrans must be equal to the appropriate CBLAS library constants!"); namespace kaldi { - +/* template void MatrixBase::Invert(Real *log_det, Real *det_sign, bool inverse_needed) { @@ -174,14 +166,19 @@ void MatrixBase::AddMatMat(const Real alpha, const MatrixBase& B, MatrixTransposeType transB, const Real beta) { - KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) - || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == +B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == +B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == +B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == +B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); KALDI_ASSERT(&A != this && &B != this); if (num_rows_ == 0) return; cblas_Xgemm(alpha, transA, A.data_, A.num_rows_, A.num_cols_, A.stride_, - transB, B.data_, B.stride_, beta, data_, num_rows_, num_cols_, stride_); + transB, B.data_, B.stride_, beta, data_, num_rows_, num_cols_, +stride_); } @@ -199,36 +196,38 @@ void MatrixBase::SetMatMatDivMat(const MatrixBase& A, id = od * (o / i); /// o / i is either zero or "scale". } else { id = od; /// Just imagine the scale was 1.0. This is somehow true in - /// expectation; anyway, this case should basically never happen so it doesn't + /// expectation; anyway, this case should basically never happen so it +doesn't /// really matter. 
} (*this)(r, c) = id; } } } +*/ +// template +// void MatrixBase::CopyLowerToUpper() { +// KALDI_ASSERT(num_rows_ == num_cols_); +// Real *data = data_; +// MatrixIndexT num_rows = num_rows_, stride = stride_; +// for (int32 i = 0; i < num_rows; i++) +// for (int32 j = 0; j < i; j++) +// data[j * stride + i ] = data[i * stride + j]; +//} -template -void MatrixBase::CopyLowerToUpper() { - KALDI_ASSERT(num_rows_ == num_cols_); - Real *data = data_; - MatrixIndexT num_rows = num_rows_, stride = stride_; - for (int32 i = 0; i < num_rows; i++) - for (int32 j = 0; j < i; j++) - data[j * stride + i ] = data[i * stride + j]; -} +// template +// void MatrixBase::CopyUpperToLower() { +// KALDI_ASSERT(num_rows_ == num_cols_); +// Real *data = data_; +// MatrixIndexT num_rows = num_rows_, stride = stride_; +// for (int32 i = 0; i < num_rows; i++) +// for (int32 j = 0; j < i; j++) +// data[i * stride + j] = data[j * stride + i]; +//} -template -void MatrixBase::CopyUpperToLower() { - KALDI_ASSERT(num_rows_ == num_cols_); - Real *data = data_; - MatrixIndexT num_rows = num_rows_, stride = stride_; - for (int32 i = 0; i < num_rows; i++) - for (int32 j = 0; j < i; j++) - data[i * stride + j] = data[j * stride + i]; -} - +/* template void MatrixBase::SymAddMat2(const Real alpha, const MatrixBase &A, @@ -270,10 +269,14 @@ void MatrixBase::AddMatSmat(const Real alpha, const MatrixBase &B, MatrixTransposeType transB, const Real beta) { - KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) - || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == +B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == +B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == +B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == +B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); KALDI_ASSERT(&A != this && &B != this); // We iterate over the columns of B. 
@@ -308,10 +311,14 @@ void MatrixBase::AddSmatMat(const Real alpha, const MatrixBase &B, MatrixTransposeType transB, const Real beta) { - KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) - || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) - || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == +B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == +B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == +B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == +B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); KALDI_ASSERT(&A != this && &B != this); MatrixIndexT Astride = A.stride_, Bstride = B.stride_, stride = this->stride_, @@ -349,7 +356,8 @@ void MatrixBase::AddSpSp(const Real alpha, const SpMatrix &A_in, // fully (to save work, we used the matrix constructor from SpMatrix). // CblasLeft means A is on the left: C <-- alpha A B + beta C if (sz == 0) return; - cblas_Xsymm(alpha, sz, A.data_, A.stride_, B.data_, B.stride_, beta, data_, stride_); + cblas_Xsymm(alpha, sz, A.data_, A.stride_, B.data_, B.stride_, beta, data_, +stride_); } template @@ -359,13 +367,15 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, if (transA == kNoTrans) { Scale(alpha + 1.0); } else { - KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self (transposed): not symmetric."); + KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self +(transposed): not symmetric."); Real *data = data_; if (alpha == 1.0) { // common case-- handle separately. for (MatrixIndexT row = 0; row < num_rows_; row++) { for (MatrixIndexT col = 0; col < row; col++) { Real *lower = data + (row * stride_) + col, *upper = data + (col - * stride_) + row; + * +stride_) + row; Real sum = *lower + *upper; *lower = *upper = sum; } @@ -375,7 +385,8 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, for (MatrixIndexT row = 0; row < num_rows_; row++) { for (MatrixIndexT col = 0; col < row; col++) { Real *lower = data + (row * stride_) + col, *upper = data + (col - * stride_) + row; + * +stride_) + row; Real lower_tmp = *lower; *lower += alpha * *upper; *upper += alpha * lower_tmp; @@ -397,7 +408,8 @@ void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, } else { KALDI_ASSERT(A.num_cols_ == num_rows_ && A.num_rows_ == num_cols_); if (num_rows_ == 0) return; - for (MatrixIndexT row = 0; row < num_rows_; row++, adata++, data += stride) + for (MatrixIndexT row = 0; row < num_rows_; row++, adata++, data += +stride) cblas_Xaxpy(num_cols_, alpha, adata, aStride, data, 1); } } @@ -510,7 +522,8 @@ void MatrixBase::AddMatSmat(Real alpha, const MatrixBase &A, Real alpha_B_kj = alpha * p.second; Real *this_col_j = this->Data() + j; // Add to entire 'j'th column of *this at once using cblas_Xaxpy. - // pass stride to write a colmun as matrices are stored in row major order. 
+ // pass stride to write a colmun as matrices are stored in row major +order. cblas_Xaxpy(this_num_rows, alpha_B_kj, a_col_k, A.stride_, this_col_j, this->stride_); //for (MatrixIndexT i = 0; i < this_num_rows; ++i) @@ -536,10 +549,11 @@ void MatrixBase::AddMatSmat(Real alpha, const MatrixBase &A, Real alpha_B_jk = alpha * p.second; const Real *a_col_k = A.Data() + k; // Add to entire 'j'th column of *this at once using cblas_Xaxpy. - // pass stride to write a column as matrices are stored in row major order. + // pass stride to write a column as matrices are stored in row major +order. cblas_Xaxpy(this_num_rows, alpha_B_jk, a_col_k, A.stride_, this_col_j, this->stride_); - //for (MatrixIndexT i = 0; i < this_num_rows; ++i) + //for (MatrixIndexT i = 0; i < this_num_rows; ++i) // this_col_j[i*this->stride_] += alpha_B_jk * a_col_k[i*A.stride_]; } } @@ -593,7 +607,8 @@ void MatrixBase::AddDiagVecMat( Real *data = data_; const Real *Mdata = M.Data(), *vdata = v.Data(); if (num_rows_ == 0) return; - for (MatrixIndexT i = 0; i < num_rows; i++, data += stride, Mdata += M_row_stride, vdata++) + for (MatrixIndexT i = 0; i < num_rows; i++, data += stride, Mdata += +M_row_stride, vdata++) cblas_Xaxpy(num_cols, alpha * *vdata, Mdata, M_col_stride, data, 1); } @@ -627,7 +642,8 @@ void MatrixBase::AddMatDiagVec( if (num_rows_ == 0) return; for (MatrixIndexT i = 0; i < num_rows; i++){ for(MatrixIndexT j = 0; j < num_cols; j ++ ){ - data[i*stride + j] += alpha * vdata[j] * Mdata[i*M_row_stride + j*M_col_stride]; + data[i*stride + j] += alpha * vdata[j] * Mdata[i*M_row_stride + +j*M_col_stride]; } } } @@ -662,8 +678,10 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, KALDI_ASSERT(s != NULL && U_in != this && V_in != this); Matrix tmpU, tmpV; - if (U_in == NULL) tmpU.Resize(this->num_rows_, 1); // work-space if U_in empty. - if (V_in == NULL) tmpV.Resize(1, this->num_cols_); // work-space if V_in empty. + if (U_in == NULL) tmpU.Resize(this->num_rows_, 1); // work-space if U_in +empty. + if (V_in == NULL) tmpV.Resize(1, this->num_cols_); // work-space if V_in +empty. /// Impementation notes: /// Lapack works in column-order, therefore the dimensions of *this are @@ -697,8 +715,10 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, KaldiBlasInt result; // query for work space - char *u_job = const_cast(U_in ? "s" : "N"); // "s" == skinny, "N" == "none." - char *v_job = const_cast(V_in ? "s" : "N"); // "s" == skinny, "N" == "none." + char *u_job = const_cast(U_in ? "s" : "N"); // "s" == skinny, "N" == +"none." + char *v_job = const_cast(V_in ? "s" : "N"); // "s" == skinny, "N" == +"none." clapack_Xgesvd(v_job, u_job, &M, &N, data_, &LDA, s->Data(), @@ -707,7 +727,8 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, &work_query, &l_work, &result); - KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong arguments"); + KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong +arguments"); l_work = static_cast(work_query); Real *p_work; @@ -725,7 +746,8 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, p_work, &l_work, &result); - KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong arguments"); + KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong +arguments"); if (result != 0) { KALDI_WARN << "CLAPACK sgesvd_ : some weird convergence not satisfied"; @@ -734,170 +756,170 @@ void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, } #endif - +*/ // Copy constructor. 
Copies data to newly allocated memory. -template -Matrix::Matrix (const MatrixBase & M, - MatrixTransposeType trans/*=kNoTrans*/) +template +Matrix::Matrix(const MatrixBase &M, + MatrixTransposeType trans /*=kNoTrans*/) : MatrixBase() { - if (trans == kNoTrans) { - Resize(M.num_rows_, M.num_cols_); - this->CopyFromMat(M); - } else { - Resize(M.num_cols_, M.num_rows_); - this->CopyFromMat(M, kTrans); - } + if (trans == kNoTrans) { + Resize(M.num_rows_, M.num_cols_); + this->CopyFromMat(M); + } else { + Resize(M.num_cols_, M.num_rows_); + this->CopyFromMat(M, kTrans); + } } // Copy constructor. Copies data to newly allocated memory. -template -Matrix::Matrix (const Matrix & M): - MatrixBase() { - Resize(M.num_rows_, M.num_cols_); - this->CopyFromMat(M); +template +Matrix::Matrix(const Matrix &M) : MatrixBase() { + Resize(M.num_rows_, M.num_cols_); + this->CopyFromMat(M); } /// Copy constructor from another type. -template -template -Matrix::Matrix(const MatrixBase & M, - MatrixTransposeType trans) : MatrixBase() { - if (trans == kNoTrans) { - Resize(M.NumRows(), M.NumCols()); - this->CopyFromMat(M); - } else { - Resize(M.NumCols(), M.NumRows()); - this->CopyFromMat(M, kTrans); - } +template +template +Matrix::Matrix(const MatrixBase &M, MatrixTransposeType trans) + : MatrixBase() { + if (trans == kNoTrans) { + Resize(M.NumRows(), M.NumCols()); + this->CopyFromMat(M); + } else { + Resize(M.NumCols(), M.NumRows()); + this->CopyFromMat(M, kTrans); + } } // Instantiate this constructor for float->double and double->float. -template -Matrix::Matrix(const MatrixBase & M, - MatrixTransposeType trans); -template -Matrix::Matrix(const MatrixBase & M, - MatrixTransposeType trans); +template Matrix::Matrix(const MatrixBase &M, + MatrixTransposeType trans); +template Matrix::Matrix(const MatrixBase &M, + MatrixTransposeType trans); -template +template inline void Matrix::Init(const MatrixIndexT rows, const MatrixIndexT cols, const MatrixStrideType stride_type) { - if (rows * cols == 0) { - KALDI_ASSERT(rows == 0 && cols == 0); - this->num_rows_ = 0; - this->num_cols_ = 0; - this->stride_ = 0; - this->data_ = NULL; - return; - } - KALDI_ASSERT(rows > 0 && cols > 0); - MatrixIndexT skip, stride; - size_t size; - void *data; // aligned memory block - void *temp; // memory block to be really freed - - // compute the size of skip and real cols - skip = ((16 / sizeof(Real)) - cols % (16 / sizeof(Real))) - % (16 / sizeof(Real)); - stride = cols + skip; - size = static_cast(rows) * static_cast(stride) - * sizeof(Real); - - // allocate the memory and set the right dimensions and parameters - if (NULL != (data = KALDI_MEMALIGN(16, size, &temp))) { - MatrixBase::data_ = static_cast (data); - MatrixBase::num_rows_ = rows; - MatrixBase::num_cols_ = cols; - MatrixBase::stride_ = (stride_type == kDefaultStride ? 
stride : cols); - } else { - throw std::bad_alloc(); - } + if (rows * cols == 0) { + KALDI_ASSERT(rows == 0 && cols == 0); + this->num_rows_ = 0; + this->num_cols_ = 0; + this->stride_ = 0; + this->data_ = NULL; + return; + } + KALDI_ASSERT(rows > 0 && cols > 0); + MatrixIndexT skip, stride; + size_t size; + void *data; // aligned memory block + void *temp; // memory block to be really freed + + // compute the size of skip and real cols + skip = ((16 / sizeof(Real)) - cols % (16 / sizeof(Real))) % + (16 / sizeof(Real)); + stride = cols + skip; + size = + static_cast(rows) * static_cast(stride) * sizeof(Real); + + // allocate the memory and set the right dimensions and parameters + if (NULL != (data = KALDI_MEMALIGN(16, size, &temp))) { + MatrixBase::data_ = static_cast(data); + MatrixBase::num_rows_ = rows; + MatrixBase::num_cols_ = cols; + MatrixBase::stride_ = + (stride_type == kDefaultStride ? stride : cols); + } else { + throw std::bad_alloc(); + } } -template +template void Matrix::Resize(const MatrixIndexT rows, const MatrixIndexT cols, MatrixResizeType resize_type, MatrixStrideType stride_type) { - // the next block uses recursion to handle what we have to do if - // resize_type == kCopyData. - if (resize_type == kCopyData) { - if (this->data_ == NULL || rows == 0) resize_type = kSetZero; // nothing to copy. - else if (rows == this->num_rows_ && cols == this->num_cols_ && - (stride_type == kDefaultStride || this->stride_ == this->num_cols_)) { return; } // nothing to do. - else { - // set tmp to a matrix of the desired size; if new matrix - // is bigger in some dimension, zero it. - MatrixResizeType new_resize_type = - (rows > this->num_rows_ || cols > this->num_cols_) ? kSetZero : kUndefined; - Matrix tmp(rows, cols, new_resize_type, stride_type); - MatrixIndexT rows_min = std::min(rows, this->num_rows_), - cols_min = std::min(cols, this->num_cols_); - tmp.Range(0, rows_min, 0, cols_min). - CopyFromMat(this->Range(0, rows_min, 0, cols_min)); - tmp.Swap(this); - // and now let tmp go out of scope, deleting what was in *this. - return; + // the next block uses recursion to handle what we have to do if + // resize_type == kCopyData. + if (resize_type == kCopyData) { + if (this->data_ == NULL || rows == 0) + resize_type = kSetZero; // nothing to copy. + else if (rows == this->num_rows_ && cols == this->num_cols_ && + (stride_type == kDefaultStride || + this->stride_ == this->num_cols_)) { + return; + } // nothing to do. + else { + // set tmp to a matrix of the desired size; if new matrix + // is bigger in some dimension, zero it. + MatrixResizeType new_resize_type = + (rows > this->num_rows_ || cols > this->num_cols_) ? kSetZero + : kUndefined; + Matrix tmp(rows, cols, new_resize_type, stride_type); + MatrixIndexT rows_min = std::min(rows, this->num_rows_), + cols_min = std::min(cols, this->num_cols_); + tmp.Range(0, rows_min, 0, cols_min) + .CopyFromMat(this->Range(0, rows_min, 0, cols_min)); + tmp.Swap(this); + // and now let tmp go out of scope, deleting what was in *this. + return; + } } - } - // At this point, resize_type == kSetZero or kUndefined. + // At this point, resize_type == kSetZero or kUndefined. 
- if (MatrixBase::data_ != NULL) { - if (rows == MatrixBase::num_rows_ - && cols == MatrixBase::num_cols_) { - if (resize_type == kSetZero) - this->SetZero(); - return; + if (MatrixBase::data_ != NULL) { + if (rows == MatrixBase::num_rows_ && + cols == MatrixBase::num_cols_) { + if (resize_type == kSetZero) this->SetZero(); + return; + } else + Destroy(); } - else - Destroy(); - } - Init(rows, cols, stride_type); - if (resize_type == kSetZero) MatrixBase::SetZero(); + Init(rows, cols, stride_type); + if (resize_type == kSetZero) MatrixBase::SetZero(); } -template -template +template +template void MatrixBase::CopyFromMat(const MatrixBase &M, MatrixTransposeType Trans) { - if (sizeof(Real) == sizeof(OtherReal) && - static_cast(M.Data()) == - static_cast(this->Data())) { - // CopyFromMat called on same data. Nothing to do (except sanity checks). - KALDI_ASSERT(Trans == kNoTrans && M.NumRows() == NumRows() && - M.NumCols() == NumCols() && M.Stride() == Stride()); - return; - } - if (Trans == kNoTrans) { - KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == M.NumCols()); - for (MatrixIndexT i = 0; i < num_rows_; i++) - (*this).Row(i).CopyFromVec(M.Row(i)); - } else { - KALDI_ASSERT(num_cols_ == M.NumRows() && num_rows_ == M.NumCols()); - int32 this_stride = stride_, other_stride = M.Stride(); - Real *this_data = data_; - const OtherReal *other_data = M.Data(); - for (MatrixIndexT i = 0; i < num_rows_; i++) - for (MatrixIndexT j = 0; j < num_cols_; j++) - this_data[i * this_stride + j] = other_data[j * other_stride + i]; - } + if (sizeof(Real) == sizeof(OtherReal) && + static_cast(M.Data()) == + static_cast(this->Data())) { + // CopyFromMat called on same data. Nothing to do (except sanity + // checks). + KALDI_ASSERT(Trans == kNoTrans && M.NumRows() == NumRows() && + M.NumCols() == NumCols() && M.Stride() == Stride()); + return; + } + if (Trans == kNoTrans) { + KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == M.NumCols()); + for (MatrixIndexT i = 0; i < num_rows_; i++) + (*this).Row(i).CopyFromVec(M.Row(i)); + } else { + KALDI_ASSERT(num_cols_ == M.NumRows() && num_rows_ == M.NumCols()); + int32 this_stride = stride_, other_stride = M.Stride(); + Real *this_data = data_; + const OtherReal *other_data = M.Data(); + for (MatrixIndexT i = 0; i < num_rows_; i++) + for (MatrixIndexT j = 0; j < num_cols_; j++) + this_data[i * this_stride + j] = + other_data[j * other_stride + i]; + } } // template instantiations. -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); -template -void MatrixBase::CopyFromMat(const MatrixBase & M, - MatrixTransposeType Trans); - +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); +template void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans); + +/* // Specialize the template for CopyFromSp for float, float. 
template<> template<> @@ -992,103 +1014,100 @@ template void MatrixBase::CopyFromTp(const TpMatrix & M, MatrixTransposeType trans); - -template +*/ +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv) { - if (rv.Dim() == num_rows_*num_cols_) { - if (stride_ == num_cols_) { - // one big copy operation. - const Real *rv_data = rv.Data(); - std::memcpy(data_, rv_data, sizeof(Real)*num_rows_*num_cols_); - } else { - const Real *rv_data = rv.Data(); - for (MatrixIndexT r = 0; r < num_rows_; r++) { - Real *row_data = RowData(r); - for (MatrixIndexT c = 0; c < num_cols_; c++) { - row_data[c] = rv_data[c]; + if (rv.Dim() == num_rows_ * num_cols_) { + if (stride_ == num_cols_) { + // one big copy operation. + const Real *rv_data = rv.Data(); + std::memcpy(data_, rv_data, sizeof(Real) * num_rows_ * num_cols_); + } else { + const Real *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + Real *row_data = RowData(r); + for (MatrixIndexT c = 0; c < num_cols_; c++) { + row_data[c] = rv_data[c]; + } + rv_data += num_cols_; + } } - rv_data += num_cols_; - } + } else if (rv.Dim() == num_cols_) { + const Real *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) + std::memcpy(RowData(r), rv_data, sizeof(Real) * num_cols_); + } else { + KALDI_ERR << "Wrong sized arguments"; } - } else if (rv.Dim() == num_cols_) { - const Real *rv_data = rv.Data(); - for (MatrixIndexT r = 0; r < num_rows_; r++) - std::memcpy(RowData(r), rv_data, sizeof(Real)*num_cols_); - } else { - KALDI_ERR << "Wrong sized arguments"; - } } -template -template +template +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv) { - if (rv.Dim() == num_rows_*num_cols_) { - const OtherReal *rv_data = rv.Data(); - for (MatrixIndexT r = 0; r < num_rows_; r++) { - Real *row_data = RowData(r); - for (MatrixIndexT c = 0; c < num_cols_; c++) { - row_data[c] = static_cast(rv_data[c]); - } - rv_data += num_cols_; + if (rv.Dim() == num_rows_ * num_cols_) { + const OtherReal *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + Real *row_data = RowData(r); + for (MatrixIndexT c = 0; c < num_cols_; c++) { + row_data[c] = static_cast(rv_data[c]); + } + rv_data += num_cols_; + } + } else if (rv.Dim() == num_cols_) { + const OtherReal *rv_data = rv.Data(); + Real *first_row_data = RowData(0); + for (MatrixIndexT c = 0; c < num_cols_; c++) + first_row_data[c] = rv_data[c]; + for (MatrixIndexT r = 1; r < num_rows_; r++) + std::memcpy(RowData(r), first_row_data, sizeof(Real) * num_cols_); + } else { + KALDI_ERR << "Wrong sized arguments."; } - } else if (rv.Dim() == num_cols_) { - const OtherReal *rv_data = rv.Data(); - Real *first_row_data = RowData(0); - for (MatrixIndexT c = 0; c < num_cols_; c++) - first_row_data[c] = rv_data[c]; - for (MatrixIndexT r = 1; r < num_rows_; r++) - std::memcpy(RowData(r), first_row_data, sizeof(Real)*num_cols_); - } else { - KALDI_ERR << "Wrong sized arguments."; - } } -template -void MatrixBase::CopyRowsFromVec(const VectorBase &rv); -template -void MatrixBase::CopyRowsFromVec(const VectorBase &rv); +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv); +template void MatrixBase::CopyRowsFromVec(const VectorBase &rv); -template +template void MatrixBase::CopyColsFromVec(const VectorBase &rv) { - if (rv.Dim() == num_rows_*num_cols_) { - const Real *v_inc_data = rv.Data(); - Real *m_inc_data = data_; + if (rv.Dim() == num_rows_ * num_cols_) { + const Real *v_inc_data = rv.Data(); + Real *m_inc_data = data_; - for (MatrixIndexT c = 0; c < 
num_cols_; c++) { - for (MatrixIndexT r = 0; r < num_rows_; r++) { - m_inc_data[r * stride_] = v_inc_data[r]; - } - v_inc_data += num_rows_; - m_inc_data ++; - } - } else if (rv.Dim() == num_rows_) { - const Real *v_inc_data = rv.Data(); - Real *m_inc_data = data_; - for (MatrixIndexT r = 0; r < num_rows_; r++) { - Real value = *(v_inc_data++); - for (MatrixIndexT c = 0; c < num_cols_; c++) - m_inc_data[c] = value; - m_inc_data += stride_; + for (MatrixIndexT c = 0; c < num_cols_; c++) { + for (MatrixIndexT r = 0; r < num_rows_; r++) { + m_inc_data[r * stride_] = v_inc_data[r]; + } + v_inc_data += num_rows_; + m_inc_data++; + } + } else if (rv.Dim() == num_rows_) { + const Real *v_inc_data = rv.Data(); + Real *m_inc_data = data_; + for (MatrixIndexT r = 0; r < num_rows_; r++) { + Real value = *(v_inc_data++); + for (MatrixIndexT c = 0; c < num_cols_; c++) m_inc_data[c] = value; + m_inc_data += stride_; + } + } else { + KALDI_ERR << "Wrong size of arguments."; } - } else { - KALDI_ERR << "Wrong size of arguments."; - } } +template +void MatrixBase::CopyRowFromVec(const VectorBase &rv, + const MatrixIndexT row) { + KALDI_ASSERT(rv.Dim() == num_cols_ && + static_cast(row) < + static_cast(num_rows_)); -template -void MatrixBase::CopyRowFromVec(const VectorBase &rv, const MatrixIndexT row) { - KALDI_ASSERT(rv.Dim() == num_cols_ && - static_cast(row) < - static_cast(num_rows_)); - - const Real *rv_data = rv.Data(); - Real *row_data = RowData(row); + const Real *rv_data = rv.Data(); + Real *row_data = RowData(row); - std::memcpy(row_data, rv_data, num_cols_ * sizeof(Real)); + std::memcpy(row_data, rv_data, num_cols_ * sizeof(Real)); } - +/* template void MatrixBase::CopyDiagFromVec(const VectorBase &rv) { KALDI_ASSERT(rv.Dim() == std::min(num_cols_, num_rows_)); @@ -1096,46 +1115,46 @@ void MatrixBase::CopyDiagFromVec(const VectorBase &rv) { Real *my_data = this->Data(); for (; rv_data != rv_end; rv_data++, my_data += (this->stride_+1)) *my_data = *rv_data; -} +}*/ -template +template void MatrixBase::CopyColFromVec(const VectorBase &rv, const MatrixIndexT col) { - KALDI_ASSERT(rv.Dim() == num_rows_ && - static_cast(col) < - static_cast(num_cols_)); + KALDI_ASSERT(rv.Dim() == num_rows_ && + static_cast(col) < + static_cast(num_cols_)); - const Real *rv_data = rv.Data(); - Real *col_data = data_ + col; + const Real *rv_data = rv.Data(); + Real *col_data = data_ + col; - for (MatrixIndexT r = 0; r < num_rows_; r++) - col_data[r * stride_] = rv_data[r]; + for (MatrixIndexT r = 0; r < num_rows_; r++) + col_data[r * stride_] = rv_data[r]; } - -template +template void Matrix::RemoveRow(MatrixIndexT i) { - KALDI_ASSERT(static_cast(i) < - static_cast(MatrixBase::num_rows_) - && "Access out of matrix"); - for (MatrixIndexT j = i + 1; j < MatrixBase::num_rows_; j++) - MatrixBase::Row(j-1).CopyFromVec( MatrixBase::Row(j)); - MatrixBase::num_rows_--; + KALDI_ASSERT( + static_cast(i) < + static_cast(MatrixBase::num_rows_) && + "Access out of matrix"); + for (MatrixIndexT j = i + 1; j < MatrixBase::num_rows_; j++) + MatrixBase::Row(j - 1).CopyFromVec(MatrixBase::Row(j)); + MatrixBase::num_rows_--; } -template +template void Matrix::Destroy() { - // we need to free the data block if it was defined - if (NULL != MatrixBase::data_) - KALDI_MEMALIGN_FREE( MatrixBase::data_); - MatrixBase::data_ = NULL; - MatrixBase::num_rows_ = MatrixBase::num_cols_ - = MatrixBase::stride_ = 0; + // we need to free the data block if it was defined + if (NULL != MatrixBase::data_) + KALDI_MEMALIGN_FREE(MatrixBase::data_); + 
MatrixBase::data_ = NULL; + MatrixBase::num_rows_ = MatrixBase::num_cols_ = + MatrixBase::stride_ = 0; } - +/* template void MatrixBase::MulElements(const MatrixBase &a) { KALDI_ASSERT(a.NumRows() == num_rows_ && a.NumCols() == num_cols_); @@ -1255,7 +1274,8 @@ template void MatrixBase::GroupPnormDeriv(const MatrixBase &input, const MatrixBase &output, Real power) { - KALDI_ASSERT(input.NumCols() == this->NumCols() && input.NumRows() == this->NumRows()); + KALDI_ASSERT(input.NumCols() == this->NumCols() && input.NumRows() == +this->NumRows()); KALDI_ASSERT(this->NumCols() % output.NumCols() == 0 && this->NumRows() == output.NumRows()); @@ -1325,25 +1345,27 @@ void MatrixBase::MulColsVec(const VectorBase &scale) { } } } +*/ -template +template void MatrixBase::SetZero() { - if (num_cols_ == stride_) - memset(data_, 0, sizeof(Real)*num_rows_*num_cols_); - else - for (MatrixIndexT row = 0; row < num_rows_; row++) - memset(data_ + row*stride_, 0, sizeof(Real)*num_cols_); + if (num_cols_ == stride_) + memset(data_, 0, sizeof(Real) * num_rows_ * num_cols_); + else + for (MatrixIndexT row = 0; row < num_rows_; row++) + memset(data_ + row * stride_, 0, sizeof(Real) * num_cols_); } -template +template void MatrixBase::Set(Real value) { - for (MatrixIndexT row = 0; row < num_rows_; row++) { - for (MatrixIndexT col = 0; col < num_cols_; col++) { - (*this)(row, col) = value; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + (*this)(row, col) = value; + } } - } } +/* template void MatrixBase::SetUnit() { SetZero(); @@ -1360,7 +1382,8 @@ void MatrixBase::SetRandn() { for (MatrixIndexT col = 0; col < nc; col += 2) { kaldi::RandGauss2(row_data + col, row_data + col + 1, &rstate); } - if (nc != num_cols_) row_data[nc] = static_cast(kaldi::RandGauss(&rstate)); + if (nc != num_cols_) row_data[nc] = +static_cast(kaldi::RandGauss(&rstate)); } } @@ -1374,305 +1397,307 @@ void MatrixBase::SetRandUniform() { } } } +*/ -template +template void MatrixBase::Write(std::ostream &os, bool binary) const { - if (!os.good()) { - KALDI_ERR << "Failed to write matrix to stream: stream not good"; - } - if (binary) { // Use separate binary and text formats, - // since in binary mode we need to know if it's float or double. - std::string my_token = (sizeof(Real) == 4 ? "FM" : "DM"); - - WriteToken(os, binary, my_token); - { - int32 rows = this->num_rows_; // make the size 32-bit on disk. - int32 cols = this->num_cols_; - KALDI_ASSERT(this->num_rows_ == (MatrixIndexT) rows); - KALDI_ASSERT(this->num_cols_ == (MatrixIndexT) cols); - WriteBasicType(os, binary, rows); - WriteBasicType(os, binary, cols); - } - if (Stride() == NumCols()) - os.write(reinterpret_cast (Data()), sizeof(Real) - * static_cast(num_rows_) * static_cast(num_cols_)); - else - for (MatrixIndexT i = 0; i < num_rows_; i++) - os.write(reinterpret_cast (RowData(i)), sizeof(Real) - * num_cols_); if (!os.good()) { - KALDI_ERR << "Failed to write matrix to stream"; - } - } else { // text mode. - if (num_cols_ == 0) { - os << " [ ]\n"; - } else { - os << " ["; - for (MatrixIndexT i = 0; i < num_rows_; i++) { - os << "\n "; - for (MatrixIndexT j = 0; j < num_cols_; j++) - os << (*this)(i, j) << " "; - } - os << "]\n"; + KALDI_ERR << "Failed to write matrix to stream: stream not good"; + } + if (binary) { // Use separate binary and text formats, + // since in binary mode we need to know if it's float or double. + std::string my_token = (sizeof(Real) == 4 ? 
"FM" : "DM"); + + WriteToken(os, binary, my_token); + { + int32 rows = this->num_rows_; // make the size 32-bit on disk. + int32 cols = this->num_cols_; + KALDI_ASSERT(this->num_rows_ == (MatrixIndexT)rows); + KALDI_ASSERT(this->num_cols_ == (MatrixIndexT)cols); + WriteBasicType(os, binary, rows); + WriteBasicType(os, binary, cols); + } + if (Stride() == NumCols()) + os.write(reinterpret_cast(Data()), + sizeof(Real) * static_cast(num_rows_) * + static_cast(num_cols_)); + else + for (MatrixIndexT i = 0; i < num_rows_; i++) + os.write(reinterpret_cast(RowData(i)), + sizeof(Real) * num_cols_); + if (!os.good()) { + KALDI_ERR << "Failed to write matrix to stream"; + } + } else { // text mode. + if (num_cols_ == 0) { + os << " [ ]\n"; + } else { + os << " ["; + for (MatrixIndexT i = 0; i < num_rows_; i++) { + os << "\n "; + for (MatrixIndexT j = 0; j < num_cols_; j++) + os << (*this)(i, j) << " "; + } + os << "]\n"; + } } - } -} - - -template -void MatrixBase::Read(std::istream & is, bool binary, bool add) { - if (add) { - Matrix tmp(num_rows_, num_cols_); - tmp.Read(is, binary, false); // read without adding. - if (tmp.num_rows_ != this->num_rows_ || tmp.num_cols_ != this->num_cols_) - KALDI_ERR << "MatrixBase::Read, size mismatch " - << this->num_rows_ << ", " << this->num_cols_ - << " vs. " << tmp.num_rows_ << ", " << tmp.num_cols_; - this->AddMat(1.0, tmp); - return; - } - // now assume add == false. - - // In order to avoid rewriting this, we just declare a Matrix and - // use it to read the data, then copy. - Matrix tmp; - tmp.Read(is, binary, false); - if (tmp.NumRows() != NumRows() || tmp.NumCols() != NumCols()) { - KALDI_ERR << "MatrixBase::Read, size mismatch " - << NumRows() << " x " << NumCols() << " versus " - << tmp.NumRows() << " x " << tmp.NumCols(); - } - CopyFromMat(tmp); } -template -void Matrix::Read(std::istream & is, bool binary, bool add) { - if (add) { +template +void MatrixBase::Read(std::istream &is, bool binary) { + // In order to avoid rewriting this, we just declare a Matrix and + // use it to read the data, then copy. Matrix tmp; - tmp.Read(is, binary, false); // read without adding. - if (this->num_rows_ == 0) this->Resize(tmp.num_rows_, tmp.num_cols_); - else { - if (this->num_rows_ != tmp.num_rows_ || this->num_cols_ != tmp.num_cols_) { - if (tmp.num_rows_ == 0) return; // do nothing in this case. - else KALDI_ERR << "Matrix::Read, size mismatch " - << this->num_rows_ << ", " << this->num_cols_ - << " vs. " << tmp.num_rows_ << ", " << tmp.num_cols_; - } + tmp.Read(is, binary); + if (tmp.NumRows() != NumRows() || tmp.NumCols() != NumCols()) { + KALDI_ERR << "MatrixBase::Read, size mismatch " << NumRows() + << " x " << NumCols() << " versus " << tmp.NumRows() << " x " + << tmp.NumCols(); } - this->AddMat(1.0, tmp); - return; - } + CopyFromMat(tmp); +} - // now assume add == false. - MatrixIndexT pos_at_start = is.tellg(); - std::ostringstream specific_error; - if (binary) { // Read in binary mode. - int peekval = Peek(is, binary); - if (peekval == 'C') { - // This code enables us to read CompressedMatrix as a regular matrix. - CompressedMatrix compressed_mat; - compressed_mat.Read(is, binary); // at this point, add == false. - this->Resize(compressed_mat.NumRows(), compressed_mat.NumCols()); - compressed_mat.CopyToMat(this); - return; - } - const char *my_token = (sizeof(Real) == 4 ? "FM" : "DM"); - char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); - if (peekval == other_token_start) { // need to instantiate the other type to read it. 
- typedef typename OtherReal::Real OtherType; // if Real == float, OtherType == double, and vice versa. - Matrix other(this->num_rows_, this->num_cols_); - other.Read(is, binary, false); // add is false at this point anyway. - this->Resize(other.NumRows(), other.NumCols()); - this->CopyFromMat(other); - return; - } - std::string token; - ReadToken(is, binary, &token); - if (token != my_token) { - if (token.length() > 20) token = token.substr(0, 17) + "..."; - specific_error << ": Expected token " << my_token << ", got " << token; - goto bad; - } - int32 rows, cols; - ReadBasicType(is, binary, &rows); // throws on error. - ReadBasicType(is, binary, &cols); // throws on error. - if ((MatrixIndexT)rows != this->num_rows_ || (MatrixIndexT)cols != this->num_cols_) { - this->Resize(rows, cols); - } - if (this->Stride() == this->NumCols() && rows*cols!=0) { - is.read(reinterpret_cast(this->Data()), - sizeof(Real)*rows*cols); - if (is.fail()) goto bad; - } else { - for (MatrixIndexT i = 0; i < (MatrixIndexT)rows; i++) { - is.read(reinterpret_cast(this->RowData(i)), sizeof(Real)*cols); - if (is.fail()) goto bad; - } - } - if (is.eof()) return; - if (is.fail()) goto bad; - return; - } else { // Text mode. - std::string str; - is >> str; // get a token - if (is.fail()) { specific_error << ": Expected \"[\", got EOF"; goto bad; } - // if ((str.compare("DM") == 0) || (str.compare("FM") == 0)) { // Back compatibility. - // is >> str; // get #rows - // is >> str; // get #cols - // is >> str; // get "[" - // } - if (str == "[]") { Resize(0, 0); return; } // Be tolerant of variants. - else if (str != "[") { - if (str.length() > 20) str = str.substr(0, 17) + "..."; - specific_error << ": Expected \"[\", got \"" << str << '"'; - goto bad; - } - // At this point, we have read "[". - std::vector* > data; - std::vector *cur_row = new std::vector; - while (1) { - int i = is.peek(); - if (i == -1) { specific_error << "Got EOF while reading matrix data"; goto cleanup; } - else if (static_cast(i) == ']') { // Finished reading matrix. - is.get(); // eat the "]". - i = is.peek(); - if (static_cast(i) == '\r') { - is.get(); - is.get(); // get \r\n (must eat what we wrote) - } else if (static_cast(i) == '\n') { is.get(); } // get \n (must eat what we wrote) - if (is.fail()) { - KALDI_WARN << "After end of matrix data, read error."; - // we got the data we needed, so just warn for this error. +template +void Matrix::Read(std::istream &is, bool binary) { + // now assume add == false. + MatrixIndexT pos_at_start = is.tellg(); + std::ostringstream specific_error; + + if (binary) { // Read in binary mode. + int peekval = Peek(is, binary); + if (peekval == 'C') { + // This code enables us to read CompressedMatrix as a regular + // matrix. + // CompressedMatrix compressed_mat; + // compressed_mat.Read(is, binary); // at this point, add == false. + // this->Resize(compressed_mat.NumRows(), compressed_mat.NumCols()); + // compressed_mat.CopyToMat(this); + return; } - // Now process the data. - if (!cur_row->empty()) data.push_back(cur_row); - else delete(cur_row); - cur_row = NULL; - if (data.empty()) { this->Resize(0, 0); return; } - else { - int32 num_rows = data.size(), num_cols = data[0]->size(); - this->Resize(num_rows, num_cols); - for (int32 i = 0; i < num_rows; i++) { - if (static_cast(data[i]->size()) != num_cols) { - specific_error << "Matrix has inconsistent #cols: " << num_cols - << " vs." << data[i]->size() << " (processing row" - << i << ")"; - goto cleanup; + const char *my_token = (sizeof(Real) == 4 ? 
"FM" : "DM"); + char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); + if (peekval == other_token_start) { // need to instantiate the other + // type to read it. + typedef typename OtherReal::Real OtherType; // if Real == + // float, + // OtherType == + // double, and + // vice versa. + Matrix other(this->num_rows_, this->num_cols_); + other.Read(is, binary); // add is false at this point anyway. + this->Resize(other.NumRows(), other.NumCols()); + this->CopyFromMat(other); + return; + } + std::string token; + ReadToken(is, binary, &token); + if (token != my_token) { + if (token.length() > 20) token = token.substr(0, 17) + "..."; + specific_error << ": Expected token " << my_token << ", got " + << token; + goto bad; + } + int32 rows, cols; + ReadBasicType(is, binary, &rows); // throws on error. + ReadBasicType(is, binary, &cols); // throws on error. + if ((MatrixIndexT)rows != this->num_rows_ || + (MatrixIndexT)cols != this->num_cols_) { + this->Resize(rows, cols); + } + if (this->Stride() == this->NumCols() && rows * cols != 0) { + is.read(reinterpret_cast(this->Data()), + sizeof(Real) * rows * cols); + if (is.fail()) goto bad; + } else { + for (MatrixIndexT i = 0; i < (MatrixIndexT)rows; i++) { + is.read(reinterpret_cast(this->RowData(i)), + sizeof(Real) * cols); + if (is.fail()) goto bad; } - for (int32 j = 0; j < num_cols; j++) - (*this)(i, j) = (*(data[i]))[j]; - delete data[i]; - data[i] = NULL; - } } + if (is.eof()) return; + if (is.fail()) goto bad; return; - } else if (static_cast(i) == '\n' || static_cast(i) == ';') { - // End of matrix row. - is.get(); - if (cur_row->size() != 0) { - data.push_back(cur_row); - cur_row = new std::vector; - cur_row->reserve(data.back()->size()); - } - } else if ( (i >= '0' && i <= '9') || i == '-' ) { // A number... - Real r; - is >> r; + } else { // Text mode. + std::string str; + is >> str; // get a token if (is.fail()) { - specific_error << "Stream failure/EOF while reading matrix data."; - goto cleanup; + specific_error << ": Expected \"[\", got EOF"; + goto bad; } - cur_row->push_back(r); - } else if (isspace(i)) { - is.get(); // eat the space and do nothing. - } else { // NaN or inf or error. - std::string str; - is >> str; - if (!KALDI_STRCASECMP(str.c_str(), "inf") || - !KALDI_STRCASECMP(str.c_str(), "infinity")) { - cur_row->push_back(std::numeric_limits::infinity()); - KALDI_WARN << "Reading infinite value into matrix."; - } else if (!KALDI_STRCASECMP(str.c_str(), "nan")) { - cur_row->push_back(std::numeric_limits::quiet_NaN()); - KALDI_WARN << "Reading NaN value into matrix."; - } else { - if (str.length() > 20) str = str.substr(0, 17) + "..."; - specific_error << "Expecting numeric matrix data, got " << str; - goto cleanup; + // if ((str.compare("DM") == 0) || (str.compare("FM") == 0)) { // Back + // compatibility. + // is >> str; // get #rows + // is >> str; // get #cols + // is >> str; // get "[" + // } + if (str == "[]") { + Resize(0, 0); + return; + } // Be tolerant of variants. + else if (str != "[") { + if (str.length() > 20) str = str.substr(0, 17) + "..."; + specific_error << ": Expected \"[\", got \"" << str << '"'; + goto bad; + } + // At this point, we have read "[". + std::vector *> data; + std::vector *cur_row = new std::vector; + while (1) { + int i = is.peek(); + if (i == -1) { + specific_error << "Got EOF while reading matrix data"; + goto cleanup; + } else if (static_cast(i) == + ']') { // Finished reading matrix. + is.get(); // eat the "]". 
+ i = is.peek(); + if (static_cast(i) == '\r') { + is.get(); + is.get(); // get \r\n (must eat what we wrote) + } else if (static_cast(i) == '\n') { + is.get(); + } // get \n (must eat what we wrote) + if (is.fail()) { + KALDI_WARN << "After end of matrix data, read error."; + // we got the data we needed, so just warn for this error. + } + // Now process the data. + if (!cur_row->empty()) + data.push_back(cur_row); + else + delete (cur_row); + cur_row = NULL; + if (data.empty()) { + this->Resize(0, 0); + return; + } else { + int32 num_rows = data.size(), num_cols = data[0]->size(); + this->Resize(num_rows, num_cols); + for (int32 i = 0; i < num_rows; i++) { + if (static_cast(data[i]->size()) != num_cols) { + specific_error + << "Matrix has inconsistent #cols: " << num_cols + << " vs." << data[i]->size() + << " (processing row" << i << ")"; + goto cleanup; + } + for (int32 j = 0; j < num_cols; j++) + (*this)(i, j) = (*(data[i]))[j]; + delete data[i]; + data[i] = NULL; + } + } + return; + } else if (static_cast(i) == '\n' || + static_cast(i) == ';') { + // End of matrix row. + is.get(); + if (cur_row->size() != 0) { + data.push_back(cur_row); + cur_row = new std::vector; + cur_row->reserve(data.back()->size()); + } + } else if ((i >= '0' && i <= '9') || i == '-') { // A number... + Real r; + is >> r; + if (is.fail()) { + specific_error + << "Stream failure/EOF while reading matrix data."; + goto cleanup; + } + cur_row->push_back(r); + } else if (isspace(i)) { + is.get(); // eat the space and do nothing. + } else { // NaN or inf or error. + std::string str; + is >> str; + if (!KALDI_STRCASECMP(str.c_str(), "inf") || + !KALDI_STRCASECMP(str.c_str(), "infinity")) { + cur_row->push_back(std::numeric_limits::infinity()); + KALDI_WARN << "Reading infinite value into matrix."; + } else if (!KALDI_STRCASECMP(str.c_str(), "nan")) { + cur_row->push_back(std::numeric_limits::quiet_NaN()); + KALDI_WARN << "Reading NaN value into matrix."; + } else { + if (str.length() > 20) str = str.substr(0, 17) + "..."; + specific_error << "Expecting numeric matrix data, got " + << str; + goto cleanup; + } + } } - } - } // Note, we never leave the while () loop before this // line (we return from it.) - cleanup: // We only reach here in case of error in the while loop above. - if(cur_row != NULL) - delete cur_row; - for (size_t i = 0; i < data.size(); i++) - if(data[i] != NULL) - delete data[i]; - // and then go on to "bad" below, where we print error. - } + cleanup: // We only reach here in case of error in the while loop above. + if (cur_row != NULL) delete cur_row; + for (size_t i = 0; i < data.size(); i++) + if (data[i] != NULL) delete data[i]; + // and then go on to "bad" below, where we print error. + } bad: - KALDI_ERR << "Failed to read matrix from stream. " << specific_error.str() - << " File position at start is " - << pos_at_start << ", currently " << is.tellg(); + KALDI_ERR << "Failed to read matrix from stream. " << specific_error.str() + << " File position at start is " << pos_at_start << ", currently " + << is.tellg(); } // Constructor... note that this is not const-safe as it would // be quite complicated to implement a "const SubMatrix" class that // would not allow its contents to be changed. -template +template SubMatrix::SubMatrix(const MatrixBase &M, const MatrixIndexT ro, const MatrixIndexT r, const MatrixIndexT co, const MatrixIndexT c) { - if (r == 0 || c == 0) { - // we support the empty sub-matrix as a special case. 
- KALDI_ASSERT(c == 0 && r == 0); - this->data_ = NULL; - this->num_cols_ = 0; - this->num_rows_ = 0; - this->stride_ = 0; - return; - } - KALDI_ASSERT(static_cast(ro) < - static_cast(M.num_rows_) && - static_cast(co) < - static_cast(M.num_cols_) && - static_cast(r) <= - static_cast(M.num_rows_ - ro) && - static_cast(c) <= - static_cast(M.num_cols_ - co)); - // point to the begining of window - MatrixBase::num_rows_ = r; - MatrixBase::num_cols_ = c; - MatrixBase::stride_ = M.Stride(); - MatrixBase::data_ = M.Data_workaround() + - static_cast(co) + - static_cast(ro) * static_cast(M.Stride()); + if (r == 0 || c == 0) { + // we support the empty sub-matrix as a special case. + KALDI_ASSERT(c == 0 && r == 0); + this->data_ = NULL; + this->num_cols_ = 0; + this->num_rows_ = 0; + this->stride_ = 0; + return; + } + KALDI_ASSERT(static_cast(ro) < + static_cast(M.num_rows_) && + static_cast(co) < + static_cast(M.num_cols_) && + static_cast(r) <= + static_cast(M.num_rows_ - ro) && + static_cast(c) <= + static_cast(M.num_cols_ - co)); + // point to the begining of window + MatrixBase::num_rows_ = r; + MatrixBase::num_cols_ = c; + MatrixBase::stride_ = M.Stride(); + MatrixBase::data_ = + M.Data_workaround() + static_cast(co) + + static_cast(ro) * static_cast(M.Stride()); } -template +template SubMatrix::SubMatrix(Real *data, MatrixIndexT num_rows, MatrixIndexT num_cols, - MatrixIndexT stride): - MatrixBase(data, num_cols, num_rows, stride) { // caution: reversed order! - if (data == NULL) { - KALDI_ASSERT(num_rows * num_cols == 0); - this->num_rows_ = 0; - this->num_cols_ = 0; - this->stride_ = 0; - } else { - KALDI_ASSERT(this->stride_ >= this->num_cols_); - } + MatrixIndexT stride) + : MatrixBase( + data, num_cols, num_rows, stride) { // caution: reversed order! + if (data == NULL) { + KALDI_ASSERT(num_rows * num_cols == 0); + this->num_rows_ = 0; + this->num_cols_ = 0; + this->stride_ = 0; + } else { + KALDI_ASSERT(this->stride_ >= this->num_cols_); + } } - +/* template void MatrixBase::Add(const Real alpha) { Real *data = data_; @@ -1697,9 +1722,11 @@ Real MatrixBase::Cond() const { KALDI_ASSERT(num_rows_ > 0&&num_cols_ > 0); Vector singular_values(std::min(num_rows_, num_cols_)); Svd(&singular_values); // Get singular values... - Real min = singular_values(0), max = singular_values(0); // both absolute values... + Real min = singular_values(0), max = singular_values(0); // both absolute +values... 
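The SubMatrix constructor above is purely pointer arithmetic: the view starts at `data + ro * stride + co` and inherits the parent's stride, so no elements are copied. A stripped-down sketch of the same windowing over a raw buffer (names are hypothetical):

```cpp
#include <cassert>
#include <cstddef>

// Minimal sketch of the windowing done by the SubMatrix constructor above:
// the sub-view shares the parent's stride and simply offsets the pointer.
struct MatView {
    float *data;
    std::ptrdiff_t rows, cols, stride;  // stride: elements between row starts
    float &at(std::ptrdiff_t r, std::ptrdiff_t c) const {
        return data[r * stride + c];
    }
};

MatView SubView(const MatView &m, std::ptrdiff_t ro, std::ptrdiff_t r,
                std::ptrdiff_t co, std::ptrdiff_t c) {
    assert(ro >= 0 && co >= 0 && ro + r <= m.rows && co + c <= m.cols);
    return MatView{m.data + ro * m.stride + co, r, c, m.stride};
}
```

Writes through such a view land in the parent's storage, which is why the constructor is documented above as not const-safe.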
for (MatrixIndexT i = 1;i < singular_values.Dim();i++) { - min = std::min((Real)std::abs(singular_values(i)), min); max = std::max((Real)std::abs(singular_values(i)), max); + min = std::min((Real)std::abs(singular_values(i)), min); max = +std::max((Real)std::abs(singular_values(i)), max); } if (min > 0) return max/min; else return std::numeric_limits::infinity(); @@ -1709,7 +1736,8 @@ template Real MatrixBase::Trace(bool check_square) const { KALDI_ASSERT(!check_square || num_rows_ == num_cols_); Real ans = 0.0; - for (MatrixIndexT r = 0;r < std::min(num_rows_, num_cols_);r++) ans += data_ [r + stride_*r]; + for (MatrixIndexT r = 0;r < std::min(num_rows_, num_cols_);r++) ans += data_ +[r + stride_*r]; return ans; } @@ -1739,22 +1767,29 @@ Real MatrixBase::Min() const { template void MatrixBase::AddMatMatMat(Real alpha, - const MatrixBase &A, MatrixTransposeType transA, - const MatrixBase &B, MatrixTransposeType transB, - const MatrixBase &C, MatrixTransposeType transC, + const MatrixBase &A, +MatrixTransposeType transA, + const MatrixBase &B, +MatrixTransposeType transB, + const MatrixBase &C, +MatrixTransposeType transC, Real beta) { - // Note on time taken with different orders of computation. Assume not transposed in this / - // discussion. Firstly, normalize expressions using A.NumCols == B.NumRows and B.NumCols == C.NumRows, prefer + // Note on time taken with different orders of computation. Assume not +transposed in this / + // discussion. Firstly, normalize expressions using A.NumCols == B.NumRows and +B.NumCols == C.NumRows, prefer // rows where there is a choice. // time taken for (AB) is: A.NumRows*B.NumRows*C.Rows // time taken for (AB)C is A.NumRows*C.NumRows*C.Cols - // so this order is A.NumRows*B.NumRows*C.NumRows + A.NumRows*C.NumRows*C.NumCols. + // so this order is A.NumRows*B.NumRows*C.NumRows + +A.NumRows*C.NumRows*C.NumCols. // time taken for (BC) is: B.NumRows*C.NumRows*C.Cols // time taken for A(BC) is: A.NumRows*B.NumRows*C.Cols // so this order is B.NumRows*C.NumRows*C.NumCols + A.NumRows*B.NumRows*C.Cols - MatrixIndexT ARows = A.num_rows_, ACols = A.num_cols_, BRows = B.num_rows_, BCols = B.num_cols_, + MatrixIndexT ARows = A.num_rows_, ACols = A.num_cols_, BRows = B.num_rows_, +BCols = B.num_cols_, CRows = C.num_rows_, CCols = C.num_cols_; if (transA == kTrans) std::swap(ARows, ACols); if (transB == kTrans) std::swap(BRows, BCols); @@ -1778,58 +1813,71 @@ void MatrixBase::AddMatMatMat(Real alpha, template -void MatrixBase::DestructiveSvd(VectorBase *s, MatrixBase *U, MatrixBase *Vt) { +void MatrixBase::DestructiveSvd(VectorBase *s, MatrixBase *U, +MatrixBase *Vt) { // Svd, *this = U*diag(s)*Vt. // With (*this).num_rows_ == m, (*this).num_cols_ == n, - // Support only skinny Svd with m>=n (NumRows>=NumCols), and zero sizes for U and Vt mean + // Support only skinny Svd with m>=n (NumRows>=NumCols), and zero sizes for U +and Vt mean // we do not want that output. We expect that s.Dim() == m, // U is either 0 by 0 or m by n, and rv is either 0 by 0 or n by n. // Throws exception on error. - KALDI_ASSERT(num_rows_>=num_cols_ && "Svd requires that #rows by >= #cols."); // For compatibility with JAMA code. + KALDI_ASSERT(num_rows_>=num_cols_ && "Svd requires that #rows by >= #cols."); +// For compatibility with JAMA code. KALDI_ASSERT(s->Dim() == num_cols_); // s should be the smaller dim. 
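The ordering note inside AddMatMatMat reduces to comparing two multiply counts once the dimensions are normalized for transposes. A sketch of that decision rule, using the same expressions as the comment:

```cpp
#include <cstdint>

// Operation-count heuristic from the AddMatMatMat comment above, with dims
// already normalized (A.NumCols == B.NumRows, B.NumCols == C.NumRows):
//   (A*B)*C costs ARows*BRows*CRows + ARows*CRows*CCols multiplies,
//   A*(B*C) costs BRows*CRows*CCols + ARows*BRows*CCols.
bool ComputeABFirst(int64_t ARows, int64_t BRows, int64_t CRows,
                    int64_t CCols) {
    const int64_t ab_first = ARows * BRows * CRows + ARows * CRows * CCols;
    const int64_t bc_first = BRows * CRows * CCols + ARows * BRows * CCols;
    return ab_first <= bc_first;
}
```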
- KALDI_ASSERT(U == NULL || (U->num_rows_ == num_rows_&&U->num_cols_ == num_cols_)); - KALDI_ASSERT(Vt == NULL || (Vt->num_rows_ == num_cols_&&Vt->num_cols_ == num_cols_)); + KALDI_ASSERT(U == NULL || (U->num_rows_ == num_rows_&&U->num_cols_ == +num_cols_)); + KALDI_ASSERT(Vt == NULL || (Vt->num_rows_ == num_cols_&&Vt->num_cols_ == +num_cols_)); Real prescale = 1.0; - if ( std::abs((*this)(0, 0) ) < 1.0e-30) { // Very tiny value... can cause problems in Svd. + if ( std::abs((*this)(0, 0) ) < 1.0e-30) { // Very tiny value... can cause +problems in Svd. Real max_elem = LargestAbsElem(); if (max_elem != 0) { prescale = 1.0 / max_elem; - if (std::abs(prescale) == std::numeric_limits::infinity()) { prescale = 1.0e+40; } + if (std::abs(prescale) == std::numeric_limits::infinity()) { +prescale = 1.0e+40; } (*this).Scale(prescale); } } #if !defined(HAVE_ATLAS) && !defined(USE_KALDI_SVD) - // "S" == skinny Svd (only one we support because of compatibility with Jama one which is only skinny), + // "S" == skinny Svd (only one we support because of compatibility with Jama +one which is only skinny), // "N"== no eigenvectors wanted. LapackGesvd(s, U, Vt); #else /* if (num_rows_ > 1 && num_cols_ > 1 && (*this)(0, 0) == (*this)(1, 1) - && Max() == Min() && (*this)(0, 0) != 0.0) { // special case that JamaSvd sometimes crashes on. - KALDI_WARN << "Jama SVD crashes on this type of matrix, perturbing it to prevent crash."; + && Max() == Min() && (*this)(0, 0) != 0.0) { // special case that JamaSvd +sometimes crashes on. + KALDI_WARN << "Jama SVD crashes on this type of matrix, perturbing it to +prevent crash."; for(int32 i = 0; i < NumRows(); i++) (*this)(i, i) *= 1.00001; }*/ - bool ans = JamaSvd(s, U, Vt); - if (Vt != NULL) Vt->Transpose(); // possibly to do: change this and also the transpose inside the JamaSvd routine. note, Vt is square. - if (!ans) { - KALDI_ERR << "Error doing Svd"; // This one will be caught. - } -#endif - if (prescale != 1.0) s->Scale(1.0/prescale); -} - -template -void MatrixBase::Svd(VectorBase *s, MatrixBase *U, MatrixBase *Vt) const { +// bool ans = JamaSvd(s, U, Vt); +// if (Vt != NULL) Vt->Transpose(); // possibly to do: change this and also the +// transpose inside the JamaSvd routine. note, Vt is square. +// if (!ans) { +// KALDI_ERR << "Error doing Svd"; // This one will be caught. +//} +//#endif +// if (prescale != 1.0) s->Scale(1.0/prescale); +//} +/* +template +void MatrixBase::Svd(VectorBase *s, MatrixBase *U, +MatrixBase *Vt) const { try { if (num_rows_ >= num_cols_) { Matrix tmp(*this); tmp.DestructiveSvd(s, U, Vt); } else { Matrix tmp(*this, kTrans); // transpose of *this. - // rVt will have different dim so cannot transpose in-place --> use a temp matrix. + // rVt will have different dim so cannot transpose in-place --> use a temp +matrix. Matrix Vt_Trans(Vt ? Vt->num_cols_ : 0, Vt ? Vt->num_rows_ : 0); // U will be transpose tmp.DestructiveSvd(s, Vt ? &Vt_Trans : NULL, U); @@ -1838,7 +1886,8 @@ void MatrixBase::Svd(VectorBase *s, MatrixBase *U, MatrixBase< } } catch (...) 
{ KALDI_ERR << "Error doing Svd (did not converge), first part of matrix is\n" - << SubMatrix(*this, 0, std::min((MatrixIndexT)10, num_rows_), + << SubMatrix(*this, 0, std::min((MatrixIndexT)10, +num_rows_), 0, std::min((MatrixIndexT)10, num_cols_)) << ", min and max are: " << Min() << ", " << Max(); } @@ -1851,7 +1900,8 @@ bool MatrixBase::IsSymmetric(Real cutoff) const { Real bad_sum = 0.0, good_sum = 0.0; for (MatrixIndexT i = 0;i < R;i++) { for (MatrixIndexT j = 0;j < i;j++) { - Real a = (*this)(i, j), b = (*this)(j, i), avg = 0.5*(a+b), diff = 0.5*(a-b); + Real a = (*this)(i, j), b = (*this)(j, i), avg = 0.5*(a+b), diff = +0.5*(a-b); good_sum += std::abs(avg); bad_sum += std::abs(diff); } good_sum += std::abs((*this)(i, i)); @@ -1892,7 +1942,8 @@ bool MatrixBase::IsUnit(Real cutoff) const { Real bad_max = 0.0; for (MatrixIndexT i = 0; i < R;i++) for (MatrixIndexT j = 0; j < C;j++) - bad_max = std::max(bad_max, static_cast(std::abs( (*this)(i, j) - (i == j?1.0:0.0)))); + bad_max = std::max(bad_max, static_cast(std::abs( (*this)(i, j) - (i +== j?1.0:0.0)))); return (bad_max <= cutoff); } @@ -1912,7 +1963,8 @@ Real MatrixBase::FrobeniusNorm() const{ } template -bool MatrixBase::ApproxEqual(const MatrixBase &other, float tol) const { +bool MatrixBase::ApproxEqual(const MatrixBase &other, float tol) +const { if (num_rows_ != other.num_rows_ || num_cols_ != other.num_cols_) KALDI_ERR << "ApproxEqual: size mismatch."; Matrix tmp(*this); @@ -1985,27 +2037,35 @@ void MatrixBase::OrthogonalizeRows() { } -// Uses Svd to compute the eigenvalue decomposition of a symmetric positive semidefinite +// Uses Svd to compute the eigenvalue decomposition of a symmetric positive +semidefinite // matrix: -// (*this) = rU * diag(rs) * rU^T, with rU an orthogonal matrix so rU^{-1} = rU^T. -// Does this by computing svd (*this) = U diag(rs) V^T ... answer is just U diag(rs) U^T. -// Throws exception if this failed to within supplied precision (typically because *this was not +// (*this) = rU * diag(rs) * rU^T, with rU an orthogonal matrix so rU^{-1} = +rU^T. +// Does this by computing svd (*this) = U diag(rs) V^T ... answer is just U +diag(rs) U^T. +// Throws exception if this failed to within supplied precision (typically +because *this was not // symmetric positive definite). template -void MatrixBase::SymPosSemiDefEig(VectorBase *rs, MatrixBase *rU, Real check_thresh) // e.g. check_thresh = 0.001 +void MatrixBase::SymPosSemiDefEig(VectorBase *rs, MatrixBase +*rU, Real check_thresh) // e.g. check_thresh = 0.001 { const MatrixIndexT D = num_rows_; KALDI_ASSERT(num_rows_ == num_cols_); - KALDI_ASSERT(IsSymmetric() && "SymPosSemiDefEig: expecting input to be symmetrical."); + KALDI_ASSERT(IsSymmetric() && "SymPosSemiDefEig: expecting input to be +symmetrical."); KALDI_ASSERT(rU->num_rows_ == D && rU->num_cols_ == D && rs->Dim() == D); Matrix Vt(D, D); Svd(rs, rU, &Vt); - // First just zero any singular values if the column of U and V do not have +ve dot product-- - // this may mean we have small negative eigenvalues, and if we zero them the result will be closer to correct. + // First just zero any singular values if the column of U and V do not have ++ve dot product-- + // this may mean we have small negative eigenvalues, and if we zero them the +result will be closer to correct. 
for (MatrixIndexT i = 0;i < D;i++) { Real sum = 0.0; for (MatrixIndexT j = 0;j < D;j++) sum += (*rU)(j, i) * Vt(i, j); @@ -2024,9 +2084,12 @@ void MatrixBase::SymPosSemiDefEig(VectorBase *rs, MatrixBase * if (!(old_norm == 0 && new_norm == 0)) { float diff_norm = tmpThisFull.FrobeniusNorm(); - if (std::abs(new_norm-old_norm) > old_norm*check_thresh || diff_norm > old_norm*check_thresh) { - KALDI_WARN << "SymPosSemiDefEig seems to have failed " << diff_norm << " !<< " - << check_thresh << "*" << old_norm << ", maybe matrix was not " + if (std::abs(new_norm-old_norm) > old_norm*check_thresh || diff_norm > +old_norm*check_thresh) { + KALDI_WARN << "SymPosSemiDefEig seems to have failed " << diff_norm << " +!<< " + << check_thresh << "*" << old_norm << ", maybe matrix was not +" << "positive semi definite. Continuing anyway."; } } @@ -2038,7 +2101,8 @@ template Real MatrixBase::LogDet(Real *det_sign) const { Real log_det; Matrix tmp(*this); - tmp.Invert(&log_det, det_sign, false); // false== output not needed (saves some computation). + tmp.Invert(&log_det, det_sign, false); // false== output not needed (saves +some computation). return log_det; } @@ -2052,29 +2116,29 @@ void MatrixBase::InvertDouble(Real *log_det, Real *det_sign, if (log_det) *log_det = log_det_tmp; if (det_sign) *det_sign = det_sign_tmp; } +*/ -template -void MatrixBase::CopyFromMat(const CompressedMatrix &mat) { - mat.CopyToMat(this); -} - -template -Matrix::Matrix(const CompressedMatrix &M): MatrixBase() { - Resize(M.NumRows(), M.NumCols(), kUndefined); - M.CopyToMat(this); -} +// template +// void MatrixBase::CopyFromMat(const CompressedMatrix &mat) { +// mat.CopyToMat(this); +//} +// template +// Matrix::Matrix(const CompressedMatrix &M): MatrixBase() { +// Resize(M.NumRows(), M.NumCols(), kUndefined); +// M.CopyToMat(this); +//} -template +template void MatrixBase::InvertElements() { - for (MatrixIndexT r = 0; r < num_rows_; r++) { - for (MatrixIndexT c = 0; c < num_cols_; c++) { - (*this)(r, c) = static_cast(1.0 / (*this)(r, c)); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + for (MatrixIndexT c = 0; c < num_cols_; c++) { + (*this)(r, c) = static_cast(1.0 / (*this)(r, c)); + } } - } } - +/* template void MatrixBase::Transpose() { KALDI_ASSERT(num_rows_ == num_cols_); @@ -2139,7 +2203,8 @@ void MatrixBase::Pow(const MatrixBase &src, Real power) { } template -void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool include_sign) { +void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool +include_sign) { KALDI_ASSERT(SameDim(*this, src)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; Real *row_data = data_; @@ -2148,9 +2213,9 @@ void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool incl row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col ++) { if (include_sign == true && src_row_data[col] < 0) { - row_data[col] = -pow(std::abs(src_row_data[col]), power); + row_data[col] = -pow(std::abs(src_row_data[col]), power); } else { - row_data[col] = pow(std::abs(src_row_data[col]), power); + row_data[col] = pow(std::abs(src_row_data[col]), power); } } } @@ -2165,7 +2230,8 @@ void MatrixBase::Floor(const MatrixBase &src, Real floor_val) { for (MatrixIndexT row = 0; row < num_rows; row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col++) - row_data[col] = (src_row_data[col] < floor_val ? floor_val : src_row_data[col]); + row_data[col] = (src_row_data[col] < floor_val ? 
floor_val : +src_row_data[col]); } } @@ -2178,7 +2244,8 @@ void MatrixBase::Ceiling(const MatrixBase &src, Real ceiling_val) { for (MatrixIndexT row = 0; row < num_rows; row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col++) - row_data[col] = (src_row_data[col] > ceiling_val ? ceiling_val : src_row_data[col]); + row_data[col] = (src_row_data[col] > ceiling_val ? ceiling_val : +src_row_data[col]); } } @@ -2204,12 +2271,14 @@ void MatrixBase::ExpSpecial(const MatrixBase &src) { for (MatrixIndexT row = 0; row < num_rows; row++,row_data += stride_, src_row_data += src.stride_) { for (MatrixIndexT col = 0; col < num_cols; col++) - row_data[col] = (src_row_data[col] < Real(0) ? kaldi::Exp(src_row_data[col]) : (src_row_data[col] + Real(1))); + row_data[col] = (src_row_data[col] < Real(0) ? +kaldi::Exp(src_row_data[col]) : (src_row_data[col] + Real(1))); } } template -void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, Real upper_limit) { +void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, +Real upper_limit) { KALDI_ASSERT(SameDim(*this, src)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; Real *row_data = data_; @@ -2219,11 +2288,11 @@ void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, for (MatrixIndexT col = 0; col < num_cols; col++) { const Real x = src_row_data[col]; if (!(x >= lower_limit)) - row_data[col] = kaldi::Exp(lower_limit); + row_data[col] = kaldi::Exp(lower_limit); else if (x > upper_limit) - row_data[col] = kaldi::Exp(upper_limit); + row_data[col] = kaldi::Exp(upper_limit); else - row_data[col] = kaldi::Exp(x); + row_data[col] = kaldi::Exp(x); } } } @@ -2250,15 +2319,15 @@ bool MatrixBase::Power(Real power) { (*this).AddMatMat(1.0, tmp, kNoTrans, P, kNoTrans, 0.0); return true; } - -template +*/ +template void Matrix::Swap(Matrix *other) { - std::swap(this->data_, other->data_); - std::swap(this->num_cols_, other->num_cols_); - std::swap(this->num_rows_, other->num_rows_); - std::swap(this->stride_, other->stride_); + std::swap(this->data_, other->data_); + std::swap(this->num_cols_, other->num_cols_); + std::swap(this->num_rows_, other->num_rows_); + std::swap(this->stride_, other->stride_); } - +/* // Repeating this comment that appeared in the header: // Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D // P^{-1}. Be careful: the relationship of D to the eigenvalues we output is @@ -2269,12 +2338,14 @@ void Matrix::Swap(Matrix *other) { // be block diagonal, with 2x2 blocks corresponding to any such pairs. If a // pair is lambda +- i*mu, D will have a corresponding 2x2 block // [lambda, mu; -mu, lambda]. -// Note that if the input matrix (*this) is non-invertible, P may not be invertible +// Note that if the input matrix (*this) is non-invertible, P may not be +invertible // so in this case instead of the equation (*this) = P D P^{-1} holding, we have // instead (*this) P = P D. // // By making the pointer arguments non-NULL or NULL, the user can choose to take -// not to take the eigenvalues directly, and/or the matrix D which is block-diagonal +// not to take the eigenvalues directly, and/or the matrix D which is +block-diagonal // with 2x2 blocks. 
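The 2x2 block convention described in the comment above is easiest to see built directly. Assuming dense n x n row-major storage, this sketch assembles D from the real and imaginary parts the way CreateEigenvalueMatrix (further down in this file) does:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch of the D matrix described above: real eigenvalues sit on the
// diagonal; a complex-conjugate pair lambda +- i*mu becomes the 2x2 block
// [lambda, mu; -mu, lambda].
void BuildEigenvalueMatrixSketch(const std::vector<double> &re,
                                 const std::vector<double> &im,
                                 std::vector<double> *D) {
    const std::size_t n = re.size();
    assert(im.size() == n);
    D->assign(n * n, 0.0);
    for (std::size_t j = 0; j < n;) {
        if (im[j] == 0.0) {  // real eigenvalue
            (*D)[j * n + j] = re[j];
            j += 1;
        } else {             // first of a complex-conjugate pair
            assert(j + 1 < n && im[j + 1] == -im[j]);
            const double lambda = re[j], mu = im[j];
            (*D)[j * n + j] = lambda;
            (*D)[j * n + j + 1] = mu;
            (*D)[(j + 1) * n + j] = -mu;
            (*D)[(j + 1) * n + j + 1] = lambda;
            j += 2;
        }
    }
}
```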
template void MatrixBase::Eig(MatrixBase *P, @@ -2298,7 +2369,7 @@ void MatrixBase::Eig(MatrixBase *P, // INT_32 mVersion; // INT_32 mSampSize; // }; - +/* template bool ReadHtk(std::istream &is, Matrix *M_ptr, HtkHeader *header_ptr) { @@ -2400,7 +2471,8 @@ template bool ReadHtk(std::istream &is, Matrix *M, HtkHeader *header_ptr); template -bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr) // header may be derived from a previous call to ReadHtk. Must be in binary mode. +bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr) // +header may be derived from a previous call to ReadHtk. Must be in binary mode. { KALDI_ASSERT(M.NumRows() == static_cast(htk_hdr.mNSamples)); KALDI_ASSERT(M.NumCols() == static_cast(htk_hdr.mSampleSize) / @@ -2502,12 +2574,14 @@ template Real TraceMatMatMat(const MatrixBase &A, MatrixTransposeType transA, const MatrixBase &B, MatrixTransposeType transB, const MatrixBase &C, MatrixTransposeType transC) { - MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), BCols = B.NumCols(), + MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), +BCols = B.NumCols(), CRows = C.NumRows(), CCols = C.NumCols(); if (transA == kTrans) std::swap(ARows, ACols); if (transB == kTrans) std::swap(BRows, BCols); if (transC == kTrans) std::swap(CRows, CCols); - KALDI_ASSERT( CCols == ARows && ACols == BRows && BCols == CRows && "TraceMatMatMat: args have mismatched dimensions."); + KALDI_ASSERT( CCols == ARows && ACols == BRows && BCols == CRows && +"TraceMatMatMat: args have mismatched dimensions."); if (ARows*BCols < std::min(BRows*CCols, CRows*ACols)) { Matrix AB(ARows, BCols); AB.AddMatMat(1.0, A, transA, B, transB, 0.0); // AB = A * B. @@ -2539,13 +2613,16 @@ Real TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, const MatrixBase &B, MatrixTransposeType transB, const MatrixBase &C, MatrixTransposeType transC, const MatrixBase &D, MatrixTransposeType transD) { - MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), BCols = B.NumCols(), - CRows = C.NumRows(), CCols = C.NumCols(), DRows = D.NumRows(), DCols = D.NumCols(); + MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), +BCols = B.NumCols(), + CRows = C.NumRows(), CCols = C.NumCols(), DRows = D.NumRows(), DCols = +D.NumCols(); if (transA == kTrans) std::swap(ARows, ACols); if (transB == kTrans) std::swap(BRows, BCols); if (transC == kTrans) std::swap(CRows, CCols); if (transD == kTrans) std::swap(DRows, DCols); - KALDI_ASSERT( DCols == ARows && ACols == BRows && BCols == CRows && CCols == DRows && "TraceMatMatMat: args have mismatched dimensions."); + KALDI_ASSERT( DCols == ARows && ACols == BRows && BCols == CRows && CCols == +DRows && "TraceMatMatMat: args have mismatched dimensions."); if (ARows*BCols < std::min(BRows*CCols, std::min(CRows*DCols, DRows*ACols))) { Matrix AB(ARows, BCols); AB.AddMatMat(1.0, A, transA, B, transB, 0.0); // AB = A * B. 
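The TraceMatMatMat(Mat) routines above exploit the cyclic property tr(ABC) = tr(BCA) = tr(CAB): they materialize whichever intermediate product has the fewest elements and fold the rest into a trace. Isolating the three-matrix selection logic (the four-matrix version extends the same idea):

```cpp
#include <algorithm>
#include <cstdint>

// Size heuristic from TraceMatMatMat above: form the single product with
// the fewest elements (A*B, B*C, or C*A), relying on trace cyclicity.
enum class FirstProduct { kAB, kBC, kCA };

FirstProduct ChooseFirstProduct(int64_t ARows, int64_t ACols, int64_t BRows,
                                int64_t BCols, int64_t CRows, int64_t CCols) {
    const int64_t ab = ARows * BCols;  // elements of A*B
    const int64_t bc = BRows * CCols;  // elements of B*C
    const int64_t ca = CRows * ACols;  // elements of C*A
    if (ab <= std::min(bc, ca)) return FirstProduct::kAB;
    return bc <= ca ? FirstProduct::kBC : FirstProduct::kCA;
}
```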
@@ -2572,13 +2649,18 @@ float TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, const MatrixBase &D, MatrixTransposeType transD); template -double TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, - const MatrixBase &B, MatrixTransposeType transB, - const MatrixBase &C, MatrixTransposeType transC, - const MatrixBase &D, MatrixTransposeType transD); +double TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType +transA, + const MatrixBase &B, MatrixTransposeType +transB, + const MatrixBase &C, MatrixTransposeType +transC, + const MatrixBase &D, MatrixTransposeType +transD); template void SortSvd(VectorBase *s, MatrixBase *U, - MatrixBase *Vt, bool sort_on_absolute_value) { + MatrixBase *Vt, bool +sort_on_absolute_value) { /// Makes sure the Svd is sorted (from greatest to least absolute value). MatrixIndexT num_singval = s->Dim(); KALDI_ASSERT(U == NULL || U->NumCols() == num_singval); @@ -2620,7 +2702,8 @@ void SortSvd(VectorBase *s, MatrixBase *U, MatrixBase *Vt, bool); template -void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, +void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase +&im, MatrixBase *D) { MatrixIndexT n = re.Dim(); KALDI_ASSERT(im.Dim() == n && D->NumRows() == n && D->NumCols() == n); @@ -2634,7 +2717,8 @@ void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase & } else { // First of a complex pair KALDI_ASSERT(j+1 < n && ApproxEqual(im(j+1), -im(j)) && ApproxEqual(re(j+1), re(j))); - /// if (im(j) < 0.0) KALDI_WARN << "Negative first im part of pair"; // TEMP + /// if (im(j) < 0.0) KALDI_WARN << "Negative first im part of pair"; // +TEMP Real lambda = re(j), mu = im(j); // create 2x2 block [lambda, mu; -mu, lambda] (*D)(j, j) = lambda; @@ -2647,10 +2731,12 @@ void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase & } template -void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, +void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase +&im, MatrixBase *D); template -void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, +void CreateEigenvalueMatrix(const VectorBase &re, const +VectorBase &im, MatrixBase *D); @@ -2691,7 +2777,8 @@ bool AttemptComplexPower(double *x_re, double *x_im, double power); template Real TraceMatMat(const MatrixBase &A, const MatrixBase &B, - MatrixTransposeType trans) { // tr(A B), equivalent to sum of each element of A times same element in B' + MatrixTransposeType trans) { // tr(A B), equivalent to sum of +each element of A times same element in B' MatrixIndexT aStride = A.stride_, bStride = B.stride_; if (trans == kNoTrans) { KALDI_ASSERT(A.NumRows() == B.NumCols() && A.NumCols() == B.NumRows()); @@ -2821,33 +2908,36 @@ void MatrixBase::GroupMax(const MatrixBase &src) { } } } - -template +*/ +template void MatrixBase::CopyCols(const MatrixBase &src, const MatrixIndexT *indices) { - KALDI_ASSERT(NumRows() == src.NumRows()); - MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, - this_stride = stride_, src_stride = src.stride_; - Real *this_data = this->data_; - const Real *src_data = src.data_; + KALDI_ASSERT(NumRows() == src.NumRows()); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, + this_stride = stride_, src_stride = src.stride_; + Real *this_data = this->data_; + const Real *src_data = src.data_; #ifdef KALDI_PARANOID - MatrixIndexT src_cols = src.NumCols(); - for (MatrixIndexT i = 0; i < num_cols; i++) - KALDI_ASSERT(indices[i] >= -1 && indices[i] < src_cols); + MatrixIndexT 
src_cols = src.NumCols(); + for (MatrixIndexT i = 0; i < num_cols; i++) + KALDI_ASSERT(indices[i] >= -1 && indices[i] < src_cols); #endif - // For the sake of memory locality we do this row by row, rather - // than doing it column-wise using cublas_Xcopy - for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data += src_stride) { - const MatrixIndexT *index_ptr = &(indices[0]); - for (MatrixIndexT c = 0; c < num_cols; c++, index_ptr++) { - if (*index_ptr < 0) this_data[c] = 0; - else this_data[c] = src_data[*index_ptr]; + // For the sake of memory locality we do this row by row, rather + // than doing it column-wise using cublas_Xcopy + for (MatrixIndexT r = 0; r < num_rows; + r++, this_data += this_stride, src_data += src_stride) { + const MatrixIndexT *index_ptr = &(indices[0]); + for (MatrixIndexT c = 0; c < num_cols; c++, index_ptr++) { + if (*index_ptr < 0) + this_data[c] = 0; + else + this_data[c] = src_data[*index_ptr]; + } } - } } - +/* template void MatrixBase::AddCols(const MatrixBase &src, const MatrixIndexT *indices) { @@ -2864,15 +2954,17 @@ void MatrixBase::AddCols(const MatrixBase &src, // For the sake of memory locality we do this row by row, rather // than doing it column-wise using cublas_Xcopy - for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data += src_stride) { + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data ++= src_stride) { const MatrixIndexT *index_ptr = &(indices[0]); for (MatrixIndexT c = 0; c < num_cols; c++, index_ptr++) { if (*index_ptr >= 0) this_data[c] += src_data[*index_ptr]; } } -} +}*/ +/* template void MatrixBase::CopyRows(const MatrixBase &src, const MatrixIndexT *indices) { @@ -2995,7 +3087,8 @@ void MatrixBase::DiffSigmoid(const MatrixBase &value, const MatrixBase &diff) { KALDI_ASSERT(SameDim(*this, value) && SameDim(*this, diff)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, - stride = stride_, value_stride = value.stride_, diff_stride = diff.stride_; + stride = stride_, value_stride = value.stride_, diff_stride = +diff.stride_; Real *data = data_; const Real *value_data = value.data_, *diff_data = diff.data_; for (MatrixIndexT r = 0; r < num_rows; r++) { @@ -3012,7 +3105,8 @@ void MatrixBase::DiffTanh(const MatrixBase &value, const MatrixBase &diff) { KALDI_ASSERT(SameDim(*this, value) && SameDim(*this, diff)); MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, - stride = stride_, value_stride = value.stride_, diff_stride = diff.stride_; + stride = stride_, value_stride = value.stride_, diff_stride = +diff.stride_; Real *data = data_; const Real *value_data = value.data_, *diff_data = diff.data_; for (MatrixIndexT r = 0; r < num_rows; r++) { @@ -3022,12 +3116,13 @@ void MatrixBase::DiffTanh(const MatrixBase &value, value_data += value_stride; diff_data += diff_stride; } -} - +}*/ +/* template template -void MatrixBase::AddVecToRows(const Real alpha, const VectorBase &v) { +void MatrixBase::AddVecToRows(const Real alpha, const +VectorBase &v) { const MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, stride = stride_; KALDI_ASSERT(v.Dim() == num_cols); @@ -3058,7 +3153,8 @@ template void MatrixBase::AddVecToRows(const double alpha, template template -void MatrixBase::AddVecToCols(const Real alpha, const VectorBase &v) { +void MatrixBase::AddVecToCols(const Real alpha, const +VectorBase &v) { const MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, stride = stride_; KALDI_ASSERT(v.Dim() == num_rows); @@ -3087,11 +3183,11 @@ template void 
MatrixBase::AddVecToCols(const double alpha, const VectorBase &v); template void MatrixBase::AddVecToCols(const double alpha, const VectorBase &v); - -//Explicit instantiation of the classes -//Apparently, it seems to be necessary that the instantiation -//happens at the end of the file. Otherwise, not all the member -//functions will get instantiated. +*/ +// Explicit instantiation of the classes +// Apparently, it seems to be necessary that the instantiation +// happens at the end of the file. Otherwise, not all the member +// functions will get instantiated. template class Matrix; template class Matrix; @@ -3100,4 +3196,4 @@ template class MatrixBase; template class SubMatrix; template class SubMatrix; -} // namespace kaldi +} // namespace kaldi diff --git a/runtime/engine/common/matrix/kaldi-matrix.h b/runtime/engine/common/matrix/kaldi-matrix.h new file mode 100644 index 00000000..d614f36f --- /dev/null +++ b/runtime/engine/common/matrix/kaldi-matrix.h @@ -0,0 +1,906 @@ +// matrix/kaldi-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University; Petr Schwarz; Yanmin Qian; +// Karel Vesely; Go Vivace Inc.; Haihua Xu +// 2017 Shiyin Kang +// 2019 Yiwen Shao + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_MATRIX_H_ +#define KALDI_MATRIX_KALDI_MATRIX_H_ 1 + +#include + +#include "matrix/matrix-common.h" + +namespace kaldi { + +/// @{ \addtogroup matrix_funcs_scalar + +/// \addtogroup matrix_group +/// @{ + +/// Base class which provides matrix operations not involving resizing +/// or allocation. Classes Matrix and SubMatrix inherit from it and take care +/// of allocation and resizing. +template +class MatrixBase { + public: + // so this child can access protected members of other instances. + friend class Matrix; + friend class SubMatrix; + // friend declarations for CUDA matrices (see ../cudamatrix/) + + /// Returns number of rows (or zero for empty matrix). + inline MatrixIndexT NumRows() const { return num_rows_; } + + /// Returns number of columns (or zero for empty matrix). + inline MatrixIndexT NumCols() const { return num_cols_; } + + /// Stride (distance in memory between each row). Will be >= NumCols. + inline MatrixIndexT Stride() const { return stride_; } + + /// Returns size in bytes of the data held by the matrix. + size_t SizeInBytes() const { + return static_cast(num_rows_) * static_cast(stride_) * + sizeof(Real); + } + + /// Gives pointer to raw data (const). + inline const Real *Data() const { return data_; } + + /// Gives pointer to raw data (non-const). 
+    inline Real *Data() { return data_; }
+
+    /// Returns pointer to data for one row (non-const)
+    inline Real *RowData(MatrixIndexT i) {
+        KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) <
+                     static_cast<UnsignedMatrixIndexT>(num_rows_));
+        return data_ + i * stride_;
+    }
+
+    /// Returns pointer to data for one row (const)
+    inline const Real *RowData(MatrixIndexT i) const {
+        KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) <
+                     static_cast<UnsignedMatrixIndexT>(num_rows_));
+        return data_ + i * stride_;
+    }
+
+    /// Indexing operator, non-const
+    /// (only checks sizes if compiled with -DKALDI_PARANOID)
+    inline Real &operator()(MatrixIndexT r, MatrixIndexT c) {
+        KALDI_PARANOID_ASSERT(
+            static_cast<UnsignedMatrixIndexT>(r) <
+                static_cast<UnsignedMatrixIndexT>(num_rows_) &&
+            static_cast<UnsignedMatrixIndexT>(c) <
+                static_cast<UnsignedMatrixIndexT>(num_cols_));
+        return *(data_ + r * stride_ + c);
+    }
+    /// Indexing operator, provided for ease of debugging (gdb doesn't work
+    /// with parenthesis operator).
+    Real &Index(MatrixIndexT r, MatrixIndexT c) { return (*this)(r, c); }
+
+    /// Indexing operator, const
+    /// (only checks sizes if compiled with -DKALDI_PARANOID)
+    inline const Real operator()(MatrixIndexT r, MatrixIndexT c) const {
+        KALDI_PARANOID_ASSERT(
+            static_cast<UnsignedMatrixIndexT>(r) <
+                static_cast<UnsignedMatrixIndexT>(num_rows_) &&
+            static_cast<UnsignedMatrixIndexT>(c) <
+                static_cast<UnsignedMatrixIndexT>(num_cols_));
+        return *(data_ + r * stride_ + c);
+    }
+
+    /* Basic setting-to-special values functions. */
+
+    /// Sets matrix to zero.
+    void SetZero();
+    /// Sets all elements to a specific value.
+    void Set(Real);
+    /// Sets to zero, except ones along diagonal [for non-square matrices too]
+
+    /// Copy given matrix. (no resize is done).
+    template <typename OtherReal>
+    void CopyFromMat(const MatrixBase<OtherReal> &M,
+                     MatrixTransposeType trans = kNoTrans);
+
+    /// Copy from compressed matrix.
+    // void CopyFromMat(const CompressedMatrix &M);
+
+    /// Copy given tpmatrix. (no resize is done).
+    // template <typename OtherReal>
+    // void CopyFromTp(const TpMatrix<OtherReal> &M,
+    //                 MatrixTransposeType trans = kNoTrans);
+
+    /// Copy from CUDA matrix. Implemented in ../cudamatrix/cu-matrix.h
+    // template <typename OtherReal>
+    // void CopyFromMat(const CuMatrixBase<OtherReal> &M,
+    //                  MatrixTransposeType trans = kNoTrans);
+
+    /// This function has two modes of operation. If v.Dim() == NumRows() *
+    /// NumCols(), then treats the vector as a row-by-row concatenation of a
+    /// matrix and copies to *this.
+    /// if v.Dim() == NumCols(), it sets each row of *this to a copy of v.
+    void CopyRowsFromVec(const VectorBase<Real> &v);
+
+    /// This version of CopyRowsFromVec is implemented in
+    /// ../cudamatrix/cu-vector.cc
+    // void CopyRowsFromVec(const CuVectorBase<Real> &v);
+
+    template <typename OtherReal>
+    void CopyRowsFromVec(const VectorBase<OtherReal> &v);
+
+    /// Copies vector into matrix, column-by-column.
+    /// Note that rv.Dim() must either equal NumRows()*NumCols() or NumRows();
+    /// this has two modes of operation.
+    void CopyColsFromVec(const VectorBase<Real> &v);
+
+    /// Copy vector into specific column of matrix.
+    void CopyColFromVec(const VectorBase<Real> &v, const MatrixIndexT col);
+    /// Copy vector into specific row of matrix.
+    void CopyRowFromVec(const VectorBase<Real> &v, const MatrixIndexT row);
+    /// Copy vector into diagonal of matrix.
+    void CopyDiagFromVec(const VectorBase<Real> &v);
+
+    /* Accessing of sub-parts of the matrix. */
+
+    /// Return specific row of matrix [const].
+    inline const SubVector<Real> Row(MatrixIndexT i) const {
+        KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(i) <
+                     static_cast<UnsignedMatrixIndexT>(num_rows_));
+        return SubVector<Real>(data_ + (i * stride_), NumCols());
+    }
+
+    /// Return specific row of matrix.
+ inline SubVector Row(MatrixIndexT i) { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return SubVector(data_ + (i * stride_), NumCols()); + } + + /// Return a sub-part of matrix. + inline SubMatrix Range(const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix( + *this, row_offset, num_rows, col_offset, num_cols); + } + inline SubMatrix RowRange(const MatrixIndexT row_offset, + const MatrixIndexT num_rows) const { + return SubMatrix(*this, row_offset, num_rows, 0, num_cols_); + } + inline SubMatrix ColRange(const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix(*this, 0, num_rows_, col_offset, num_cols); + } + + /* + /// Returns sum of all elements in matrix. + Real Sum() const; + /// Returns trace of matrix. + Real Trace(bool check_square = true) const; + // If check_square = true, will crash if matrix is not square. + + /// Returns maximum element of matrix. + Real Max() const; + /// Returns minimum element of matrix. + Real Min() const; + + /// Element by element multiplication with a given matrix. + void MulElements(const MatrixBase &A); + + /// Divide each element by the corresponding element of a given matrix. + void DivElements(const MatrixBase &A); + + /// Multiply each element with a scalar value. + void Scale(Real alpha); + + /// Set, element-by-element, *this = max(*this, A) + void Max(const MatrixBase &A); + /// Set, element-by-element, *this = min(*this, A) + void Min(const MatrixBase &A); + + /// Equivalent to (*this) = (*this) * diag(scale). Scaling + /// each column by a scalar taken from that dimension of the vector. + void MulColsVec(const VectorBase &scale); + + /// Equivalent to (*this) = diag(scale) * (*this). Scaling + /// each row by a scalar taken from that dimension of the vector. + void MulRowsVec(const VectorBase &scale); + + /// Divide each row into src.NumCols() equal groups, and then scale i'th + row's + /// j'th group of elements by src(i, j). Requires src.NumRows() == + /// this->NumRows() and this->NumCols() % src.NumCols() == 0. + void MulRowsGroupMat(const MatrixBase &src); + + /// Returns logdet of matrix. + Real LogDet(Real *det_sign = NULL) const; + + /// matrix inverse. + /// if inverse_needed = false, will fill matrix with garbage. + /// (only useful if logdet wanted). + void Invert(Real *log_det = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + /// matrix inverse [double]. + /// if inverse_needed = false, will fill matrix with garbage + /// (only useful if logdet wanted). + /// Does inversion in double precision even if matrix was not double. + void InvertDouble(Real *LogDet = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + */ + /// Inverts all the elements of the matrix + void InvertElements(); + /* + /// Transpose the matrix. This one is only + /// applicable to square matrices (the one in the + /// Matrix child class works also for non-square. + void Transpose(); + + */ + /// Copies column r from column indices[r] of src. + /// As a special case, if indexes[i] == -1, sets column i to zero. + /// all elements of "indices" must be in [-1, src.NumCols()-1], + /// and src.NumRows() must equal this.NumRows() + void CopyCols(const MatrixBase &src, const MatrixIndexT *indices); + + /// Copies row r from row indices[r] of src (does nothing + /// As a special case, if indexes[i] == -1, sets row i to zero. 
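The -1 convention in the surrounding comments matches the CopyCols implementation kept alive in the .cc file: indices[c] names the source column for destination column c, and -1 zero-fills it, with the copy done row by row for memory locality. A contiguous-buffer sketch (hypothetical free function, stride == cols assumed):

```cpp
#include <cstddef>
#include <vector>

// Sketch of the CopyCols semantics described above on contiguous row-major
// buffers: dst column c takes src column indices[c]; -1 zeroes the column.
void CopyColsSketch(const std::vector<float> &src, std::size_t src_cols,
                    const std::vector<int> &indices,
                    std::vector<float> *dst) {
    const std::size_t rows = src.size() / src_cols;
    const std::size_t dst_cols = indices.size();
    dst->assign(rows * dst_cols, 0.0f);  // -1 entries stay zero
    for (std::size_t r = 0; r < rows; r++)          // row-major traversal
        for (std::size_t c = 0; c < dst_cols; c++)  // for cache locality
            if (indices[c] >= 0)
                (*dst)[r * dst_cols + c] =
                    src[r * src_cols + static_cast<std::size_t>(indices[c])];
}
```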
+ /// all elements of "indices" must be in [-1, src.NumRows()-1], + /// and src.NumCols() must equal this.NumCols() + void CopyRows(const MatrixBase &src, const MatrixIndexT *indices); + + /// Add column indices[r] of src to column r. + /// As a special case, if indexes[i] == -1, skip column i + /// indices.size() must equal this->NumCols(), + /// all elements of "reorder" must be in [-1, src.NumCols()-1], + /// and src.NumRows() must equal this.NumRows() + // void AddCols(const MatrixBase &src, + // const MatrixIndexT *indices); + + /// Copies row r of this matrix from an array of floats at the location + /// given + /// by src[r]. If any src[r] is NULL then this.Row(r) will be set to zero. + /// Note: we are using "pointer to const pointer to const object" for "src", + /// because we may create "src" by calling Data() of const CuArray + void CopyRows(const Real *const *src); + + /// Copies row r of this matrix to the array of floats at the location given + /// by dst[r]. If dst[r] is NULL, does not copy anywhere. Requires that + /// none + /// of the memory regions pointed to by the pointers in "dst" overlap (e.g. + /// none of the pointers should be the same). + void CopyToRows(Real *const *dst) const; + + /// Does for each row r, this.Row(r) += alpha * src.row(indexes[r]). + /// If indexes[r] < 0, does not add anything. all elements of "indexes" must + /// be in [-1, src.NumRows()-1], and src.NumCols() must equal + /// this.NumCols(). + // void AddRows(Real alpha, + // const MatrixBase &src, + // const MatrixIndexT *indexes); + + /// Does for each row r, this.Row(r) += alpha * src[r], treating src[r] as + /// the + /// beginning of a region of memory representing a vector of floats, of the + /// same length as this.NumCols(). If src[r] is NULL, does not add anything. + // void AddRows(Real alpha, const Real *const *src); + + /// For each row r of this matrix, adds it (times alpha) to the array of + /// floats at the location given by dst[r]. If dst[r] is NULL, does not do + /// anything for that row. Requires that none of the memory regions pointed + /// to by the pointers in "dst" overlap (e.g. none of the pointers should be + /// the same). + // void AddToRows(Real alpha, Real *const *dst) const; + + /// For each row i of *this, adds this->Row(i) to + /// dst->Row(indexes(i)) if indexes(i) >= 0, else do nothing. + /// Requires that all the indexes[i] that are >= 0 + /// be distinct, otherwise the behavior is undefined. + // void AddToRows(Real alpha, + // const MatrixIndexT *indexes, + // MatrixBase *dst) const; + /* + inline void ApplyPow(Real power) { + this -> Pow(*this, power); + } + + + inline void ApplyPowAbs(Real power, bool include_sign=false) { + this -> PowAbs(*this, power, include_sign); + } + + inline void ApplyHeaviside() { + this -> Heaviside(*this); + } + + inline void ApplyFloor(Real floor_val) { + this -> Floor(*this, floor_val); + } + + inline void ApplyCeiling(Real ceiling_val) { + this -> Ceiling(*this, ceiling_val); + } + + inline void ApplyExp() { + this -> Exp(*this); + } + + inline void ApplyExpSpecial() { + this -> ExpSpecial(*this); + } + + inline void ApplyExpLimited(Real lower_limit, Real upper_limit) { + this -> ExpLimited(*this, lower_limit, upper_limit); + } + + inline void ApplyLog() { + this -> Log(*this); + } + */ + /// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = + /// P D + /// P^{-1}. Be careful: the relationship of D to the eigenvalues we output + /// is + /// slightly complicated, due to the need for P to be real. 
In the + /// symmetric + /// case D is diagonal and real, but in + /// the non-symmetric case there may be complex-conjugate pairs of + /// eigenvalues. + /// In this case, for the equation (*this) = P D P^{-1} to hold, D must + /// actually + /// be block diagonal, with 2x2 blocks corresponding to any such pairs. If + /// a + /// pair is lambda +- i*mu, D will have a corresponding 2x2 block + /// [lambda, mu; -mu, lambda]. + /// Note that if the input matrix (*this) is non-invertible, P may not be + /// invertible + /// so in this case instead of the equation (*this) = P D P^{-1} holding, we + /// have + /// instead (*this) P = P D. + /// + /// The non-member function CreateEigenvalueMatrix creates D from eigs_real + /// and eigs_imag. + // void Eig(MatrixBase *P, + // VectorBase *eigs_real, + // VectorBase *eigs_imag) const; + + /// The Power method attempts to take the matrix to a power using a method + /// that + /// works in general for fractional and negative powers. The input matrix + /// must + /// be invertible and have reasonable condition (or we don't guarantee the + /// results. The method is based on the eigenvalue decomposition. It will + /// return false and leave the matrix unchanged, if at entry the matrix had + /// real negative eigenvalues (or if it had zero eigenvalues and the power + /// was + /// negative). + // bool Power(Real pow); + + /** Singular value decomposition + Major limitations: + For nonsquare matrices, we assume m>=n (NumRows >= NumCols), and we + return + the "skinny" Svd, i.e. the matrix in the middle is diagonal, and the + one on the left is rectangular. + + In Svd, *this = U*diag(S)*Vt. + Null pointers for U and/or Vt at input mean we do not want that output. + We + expect that S.Dim() == m, U is either NULL or m by n, + and v is either NULL or n by n. + The singular values are not sorted (use SortSvd for that). */ + // void DestructiveSvd(VectorBase *s, MatrixBase *U, + // MatrixBase *Vt); // Destroys calling matrix. + + /// Compute SVD (*this) = U diag(s) Vt. Note that the V in the call is + /// already + /// transposed; the normal formulation is U diag(s) V^T. + /// Null pointers for U or V mean we don't want that output (this saves + /// compute). The singular values are not sorted (use SortSvd for that). + // void Svd(VectorBase *s, MatrixBase *U, + // MatrixBase *Vt) const; + /// Compute SVD but only retain the singular values. + // void Svd(VectorBase *s) const { Svd(s, NULL, NULL); } + + + /// Returns smallest singular value. + // Real MinSingularValue() const { + // Vector tmp(std::min(NumRows(), NumCols())); + // Svd(&tmp); + // return tmp.Min(); + //} + + // void TestUninitialized() const; // This function is designed so that if + // any element + // if the matrix is uninitialized memory, valgrind will complain. + + /// Returns condition number by computing Svd. Works even if cols > rows. + /// Returns infinity if all singular values are zero. + /* + Real Cond() const; + + /// Returns true if matrix is Symmetric. + bool IsSymmetric(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is Diagonal. + bool IsDiagonal(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if the matrix is all zeros, except for ones on diagonal. + (it + /// does not have to be square). More specifically, this function returns + /// false if for any i, j, (*this)(i, j) differs by more than cutoff from + the + /// expression (i == j ? 1 : 0). 
+ bool IsUnit(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-05) const; // replace magic number + + /// Frobenius norm, which is the sqrt of sum of square elements. Same as + Schatten 2-norm, + /// or just "2-norm". + Real FrobeniusNorm() const; + + /// Returns true if ((*this)-other).FrobeniusNorm() + /// <= tol * (*this).FrobeniusNorm(). + bool ApproxEqual(const MatrixBase &other, float tol = 0.01) const; + + /// Tests for exact equality. It's usually preferable to use ApproxEqual. + bool Equal(const MatrixBase &other) const; + + /// largest absolute value. + Real LargestAbsElem() const; // largest absolute value. + + /// Returns log(sum(exp())) without exp overflow + /// If prune > 0.0, it uses a pruning beam, discarding + /// terms less than (max - prune). Note: in future + /// we may change this so that if prune = 0.0, it takes + /// the max, so use -1 if you don't want to prune. + Real LogSumExp(Real prune = -1.0) const; + + /// Apply soft-max to the collection of all elements of the + /// matrix and return normalizer (log sum of exponentials). + Real ApplySoftMax(); + + /// Set each element to the sigmoid of the corresponding element of "src". + void Sigmoid(const MatrixBase &src); + + /// Sets each element to the Heaviside step function (x > 0 ? 1 : 0) of the + /// corresponding element in "src". Note: in general you can make different + /// choices for x = 0, but for now please leave it as it (i.e. returning + zero) + /// because it affects the RectifiedLinearComponent in the neural net code. + void Heaviside(const MatrixBase &src); + + void Exp(const MatrixBase &src); + + void Pow(const MatrixBase &src, Real power); + + void Log(const MatrixBase &src); + + /// Apply power to the absolute value of each element. + /// If include_sign is true, the result will be multiplied with + /// the sign of the input value. + /// If the power is negative and the input to the power is zero, + /// The output will be set zero. If include_sign is true, it will + /// multiply the result by the sign of the input. + void PowAbs(const MatrixBase &src, Real power, bool + include_sign=false); + + void Floor(const MatrixBase &src, Real floor_val); + + void Ceiling(const MatrixBase &src, Real ceiling_val); + + /// For each element x of the matrix, set it to + /// (x < 0 ? exp(x) : x + 1). This function is used + /// in our RNNLM training. + void ExpSpecial(const MatrixBase &src); + + /// This is equivalent to running: + /// Floor(src, lower_limit); + /// Ceiling(src, upper_limit); + /// Exp(src) + void ExpLimited(const MatrixBase &src, Real lower_limit, Real + upper_limit); + + /// Set each element to y = log(1 + exp(x)) + void SoftHinge(const MatrixBase &src); + + /// Apply the function y(i) = (sum_{j = i*G}^{(i+1)*G-1} x_j^(power))^(1 / + p). + /// Requires src.NumRows() == this->NumRows() and src.NumCols() % + this->NumCols() == 0. + void GroupPnorm(const MatrixBase &src, Real power); + + /// Calculate derivatives for the GroupPnorm function above... + /// if "input" is the input to the GroupPnorm function above (i.e. the "src" + variable), + /// and "output" is the result of the computation (i.e. the "this" of that + function + /// call), and *this has the same dimension as "input", then it sets each + element + /// of *this to the derivative d(output-elem)/d(input-elem) for each element + of "input", where + /// "output-elem" is whichever element of output depends on that input + element. 
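Several of these commented-out element-wise kernels share one shape; ExpLimited, for instance, is just Floor, Ceiling, and Exp fused into a single pass. A buffer-level sketch, mirroring the implementation shown earlier in the .cc file (note the `!(x >= lower_limit)` form, which also routes NaN to the lower limit):

```cpp
#include <cmath>
#include <cstddef>

// Sketch of ExpLimited above: clamp to [lower_limit, upper_limit], then
// exponentiate.  !(x >= lower_limit) deliberately also catches NaN.
void ExpLimitedSketch(const float *src, float *dst, std::size_t n,
                      float lower_limit, float upper_limit) {
    for (std::size_t i = 0; i < n; i++) {
        float x = src[i];
        if (!(x >= lower_limit))
            x = lower_limit;
        else if (x > upper_limit)
            x = upper_limit;
        dst[i] = std::exp(x);
    }
}
```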
+    /// Calculate derivatives for the GroupPnorm function above, where
+    /// "input" is the input to the GroupPnorm function above (i.e. the "src"
+    /// variable), and "output" is the result of the computation (i.e. the
+    /// "this" of that function call), and *this has the same dimension as
+    /// "input". It sets each element of *this to the derivative
+    /// d(output-elem)/d(input-elem) for each element of "input", where
+    /// "output-elem" is whichever element of output depends on that input
+    /// element.
+    void GroupPnormDeriv(const MatrixBase<Real> &input,
+                         const MatrixBase<Real> &output,
+                         Real power);
+
+    /// Apply the function y(i) = (max_{j = i*G}^{(i+1)*G-1} x_j).
+    /// Requires src.NumRows() == this->NumRows() and
+    /// src.NumCols() % this->NumCols() == 0.
+    void GroupMax(const MatrixBase<Real> &src);
+
+    /// Calculate derivatives for the GroupMax function above, where
+    /// "input" is the input to the GroupMax function above (i.e. the "src"
+    /// variable), and "output" is the result of the computation (i.e. the
+    /// "this" of that function call), and *this must have the same dimension
+    /// as "input". Each element of *this will be set to 1 if the
+    /// corresponding input equals the output of the group, and 0 otherwise.
+    /// This equals the function derivative where it is defined (it's not
+    /// defined where multiple inputs in the group are equal to the output).
+    void GroupMaxDeriv(const MatrixBase<Real> &input,
+                       const MatrixBase<Real> &output);
+
+    /// Set each element to the tanh of the corresponding element of "src".
+    void Tanh(const MatrixBase<Real> &src);
+
+    // Function used in backpropagating derivatives of the sigmoid function:
+    // element-by-element, set *this = diff * value * (1.0 - value).
+    void DiffSigmoid(const MatrixBase<Real> &value,
+                     const MatrixBase<Real> &diff);
+
+    // Function used in backpropagating derivatives of the tanh function:
+    // element-by-element, set *this = diff * (1.0 - value^2).
+    void DiffTanh(const MatrixBase<Real> &value,
+                  const MatrixBase<Real> &diff);
+    */
+    /** Uses Svd to compute the eigenvalue decomposition of a symmetric
+     * positive semi-definite matrix: (*this) = rP * diag(rS) * rP^T, with rP
+     * an orthogonal matrix so rP^{-1} = rP^T. Throws exception if input was
+     * not positive semi-definite (check_thresh controls how stringent the
+     * check is; set it to 2 to ensure it won't ever complain, but it will
+     * zero out negative dimensions in your matrix).
+     *
+     * Caution: if you want the eigenvalues, it may make more sense to convert
+     * to SpMatrix and use the Eig() function there, which uses eigenvalue
+     * decomposition directly rather than SVD.
+     */
+
+    /// Stream read.
+    /// Use instead of stream >> *this, if you want to add to existing
+    /// contents.
+    // Will throw exception on failure.
+    void Read(std::istream &in, bool binary);
+    /// Write to stream.
+    void Write(std::ostream &out, bool binary) const;
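+
+    // Editor's sketch (not part of the original sources): a typical binary
+    // round trip through the two calls above, assuming <fstream> is
+    // included. Unlike Matrix::Read, MatrixBase::Read requires the
+    // destination to already have the right size.
+    //   Matrix<float> m(2, 3), m2(2, 3);
+    //   { std::ofstream os("m.bin", std::ios::binary); m.Write(os, true); }
+    //   { std::ifstream is("m.bin", std::ios::binary); m2.Read(is, true); }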
+    // Below are internal methods for Svd; the user does not have to know
+    // about them.
+  protected:
+    /// Initializer, callable only from child.
+    explicit MatrixBase(Real *data,
+                        MatrixIndexT cols,
+                        MatrixIndexT rows,
+                        MatrixIndexT stride)
+        : data_(data), num_cols_(cols), num_rows_(rows), stride_(stride) {
+        KALDI_ASSERT_IS_FLOATING_TYPE(Real);
+    }
+
+    /// Initializer, callable only from child.
+    /// Empty initializer, for un-initialized matrix.
+    explicit MatrixBase() : data_(NULL) { KALDI_ASSERT_IS_FLOATING_TYPE(Real); }
+
+    // Make sure pointers to MatrixBase cannot be deleted.
+    ~MatrixBase() {}
+
+    /// A workaround that allows SubMatrix to get a pointer to non-const data
+    /// for const Matrix. Unfortunately C++ does not allow us to declare a
+    /// "public const" inheritance or anything like that, so it would require
+    /// a lot of work to make the SubMatrix class totally const-correct --
+    /// we would have to override many of the Matrix functions.
+    inline Real *Data_workaround() const { return data_; }
+
+    /// data memory area
+    Real *data_;
+
+    /// these attributes store the real matrix size as it is stored in memory
+    /// including memalignment
+    MatrixIndexT num_cols_;  ///< Number of columns
+    MatrixIndexT num_rows_;  ///< Number of rows
+    /** True number of columns for the internal matrix. This number may differ
+     *  from num_cols_ as memory alignment might be used. */
+    MatrixIndexT stride_;
+
+  private:
+    KALDI_DISALLOW_COPY_AND_ASSIGN(MatrixBase);
+};
+
+/// A class for storing matrices.
+template <typename Real>
+class Matrix : public MatrixBase<Real> {
+  public:
+    /// Empty constructor.
+    Matrix();
+
+    /// Basic constructor.
+    Matrix(const MatrixIndexT r,
+           const MatrixIndexT c,
+           MatrixResizeType resize_type = kSetZero,
+           MatrixStrideType stride_type = kDefaultStride)
+        : MatrixBase<Real>() {
+        Resize(r, c, resize_type, stride_type);
+    }
+
+    /// Swaps the contents of *this and *other. Shallow swap.
+    void Swap(Matrix<Real> *other);
+
+    /// Constructor from any MatrixBase. Can also copy with transpose.
+    /// Allocates new memory.
+    explicit Matrix(const MatrixBase<Real> &M,
+                    MatrixTransposeType trans = kNoTrans);
+
+    /// Same as above, but need to avoid default copy constructor.
+    Matrix(const Matrix<Real> &M);  // (cannot make explicit)
+
+    /// Copy constructor: as above, but from another type.
+    template <typename OtherReal>
+    explicit Matrix(const MatrixBase<OtherReal> &M,
+                    MatrixTransposeType trans = kNoTrans);
+
+    /// Copy constructor taking TpMatrix...
+    // template <typename OtherReal>
+    // explicit Matrix(const TpMatrix<OtherReal> &M,
+    //                 MatrixTransposeType trans = kNoTrans)
+    //         : MatrixBase<Real>() {
+    //     if (trans == kNoTrans) {
+    //         Resize(M.NumRows(), M.NumCols(), kUndefined);
+    //         this->CopyFromTp(M);
+    //     } else {
+    //         Resize(M.NumCols(), M.NumRows(), kUndefined);
+    //         this->CopyFromTp(M, kTrans);
+    //     }
+    // }
+
+    /// Read from stream.
+    // Unlike the one in the base class, this allows resizing.
+    void Read(std::istream &in, bool binary);
+
+    /// Remove a specified row.
+    void RemoveRow(MatrixIndexT i);
+
+    /// Transpose the matrix. Works for non-square
+    /// matrices as well as square ones.
+    // void Transpose();
+
+    /// Destructor to free matrices.
+    ~Matrix() { Destroy(); }
+
+    /// Sets matrix to a specified size (zero is OK as long as both r and c
+    /// are zero). The value of the new data depends on resize_type:
+    ///   -if kSetZero, the new data will be zero
+    ///   -if kUndefined, the new data will be undefined
+    ///   -if kCopyData, the new data will be the same as the old data in any
+    ///      shared positions, and zero elsewhere.
+    ///
+    /// You can set stride_type to kStrideEqualNumCols to force the stride
+    /// to equal the number of columns; by default it is set so that the
+    /// stride in bytes is a multiple of 16.
+    ///
+    /// This function takes time proportional to the number of data elements.
+    void Resize(const MatrixIndexT r,
+                const MatrixIndexT c,
+                MatrixResizeType resize_type = kSetZero,
+                MatrixStrideType stride_type = kDefaultStride);
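+
+    // Editor's sketch (not part of the original sources): how the resize
+    // types above behave in practice.
+    //   Matrix<float> m(2, 2);        // allocated and zero-initialized
+    //   m(0, 0) = 1.0f;
+    //   m.Resize(3, 3, kCopyData);    // m(0, 0) still 1.0; new cells zero
+    //   m.Resize(4, 4, kUndefined);   // contents now undefined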
+    /// Assignment operator that takes MatrixBase.
+    Matrix<Real> &operator=(const MatrixBase<Real> &other) {
+        if (MatrixBase<Real>::NumRows() != other.NumRows() ||
+            MatrixBase<Real>::NumCols() != other.NumCols())
+            Resize(other.NumRows(), other.NumCols(), kUndefined);
+        MatrixBase<Real>::CopyFromMat(other);
+        return *this;
+    }
+
+    /// Assignment operator. Needed for inclusion in std::vector.
+    Matrix<Real> &operator=(const Matrix<Real> &other) {
+        if (MatrixBase<Real>::NumRows() != other.NumRows() ||
+            MatrixBase<Real>::NumCols() != other.NumCols())
+            Resize(other.NumRows(), other.NumCols(), kUndefined);
+        MatrixBase<Real>::CopyFromMat(other);
+        return *this;
+    }
+
+
+  private:
+    /// Deallocates memory and sets to empty matrix (dimension 0, 0).
+    void Destroy();
+
+    /// Init assumes the current class contents are invalid (i.e. junk or
+    /// have already been freed), and it sets the matrix to newly allocated
+    /// memory with the specified number of rows and columns.
+    /// r == c == 0 is acceptable. The data memory contents will be undefined.
+    void Init(const MatrixIndexT r,
+              const MatrixIndexT c,
+              const MatrixStrideType stride_type);
+};
+/// @} end "addtogroup matrix_group"
+
+/// \addtogroup matrix_funcs_io
+/// @{
+
+template <typename Real>
+class SubMatrix : public MatrixBase<Real> {
+  public:
+    // Initialize a SubMatrix from part of a matrix; this is
+    // a bit like A(b:c, d:e) in Matlab.
+    // This initializer is against the proper semantics of "const", since
+    // SubMatrix can change its contents. It would be hard to implement
+    // a "const-safe" version of this class.
+    SubMatrix(const MatrixBase<Real> &T,
+              const MatrixIndexT ro,  // row offset, 0 <= ro < NumRows()
+              const MatrixIndexT r,   // number of rows, r > 0
+              const MatrixIndexT co,  // column offset, 0 <= co < NumCols()
+              const MatrixIndexT c);  // number of columns, c > 0
+
+    // This initializer is mostly intended for use in CuMatrix and related
+    // classes. Be careful!
+    SubMatrix(Real *data,
+              MatrixIndexT num_rows,
+              MatrixIndexT num_cols,
+              MatrixIndexT stride);
+
+    ~SubMatrix() {}
+
+    /// This type of constructor is needed for Range() to work [in Matrix base
+    /// class]. Cannot make it explicit.
+    SubMatrix(const SubMatrix &other)
+        : MatrixBase<Real>(
+              other.data_, other.num_cols_, other.num_rows_, other.stride_) {}
+
+  private:
+    /// Disallow assignment.
+    SubMatrix<Real> &operator=(const SubMatrix<Real> &other);
+};
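+
+// Editor's sketch (not part of the original sources): SubMatrix is a
+// non-owning view, roughly A(b:c, d:e) in Matlab notation as the comment
+// above says; writes through the view modify the parent matrix.
+//   Matrix<float> m(4, 4);
+//   SubMatrix<float> block(m, 1, 2, 1, 2);  // 2x2 view with origin (1, 1)
+//   block(0, 0) = 1.0f;                     // aliases m(1, 1)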
+/// @} End of "addtogroup matrix_funcs_io".
+
+/// \addtogroup matrix_funcs_scalar
+/// @{
+
+// Some declarations. These are traces of products.
+
+/************
+template <typename Real>
+bool ApproxEqual(const MatrixBase<Real> &A,
+                 const MatrixBase<Real> &B, Real tol = 0.01) {
+    return A.ApproxEqual(B, tol);
+}
+
+template <typename Real>
+inline void AssertEqual(const MatrixBase<Real> &A, const MatrixBase<Real> &B,
+                        float tol = 0.01) {
+    KALDI_ASSERT(A.ApproxEqual(B, tol));
+}
+
+/// Returns trace of matrix.
+template <typename Real>
+double TraceMat(const MatrixBase<Real> &A) { return A.Trace(); }
+
+
+/// Returns tr(A B C)
+template <typename Real>
+Real TraceMatMatMat(const MatrixBase<Real> &A, MatrixTransposeType transA,
+                    const MatrixBase<Real> &B, MatrixTransposeType transB,
+                    const MatrixBase<Real> &C, MatrixTransposeType transC);
+
+/// Returns tr(A B C D)
+template <typename Real>
+Real TraceMatMatMatMat(const MatrixBase<Real> &A, MatrixTransposeType transA,
+                       const MatrixBase<Real> &B, MatrixTransposeType transB,
+                       const MatrixBase<Real> &C, MatrixTransposeType transC,
+                       const MatrixBase<Real> &D, MatrixTransposeType transD);
+
+/// @} end "addtogroup matrix_funcs_scalar"
+
+
+/// \addtogroup matrix_funcs_misc
+/// @{
+
+
+/// Function to ensure that SVD is sorted. This function is made as generic
+/// as possible, to be applicable to other types of problems. s->Dim() should
+/// be the same as U->NumCols(), and we sort s from greatest to least absolute
+/// value (if sort_on_absolute_value == true) or greatest to least value
+/// otherwise, moving the columns of U, if it exists, and the rows of Vt, if
+/// it exists, around in the same way. Note: the "absolute value" part won't
+/// matter if this is an actual SVD, since singular values are non-negative.
+template <typename Real>
+void SortSvd(VectorBase<Real> *s, MatrixBase<Real> *U,
+             MatrixBase<Real> *Vt = NULL,
+             bool sort_on_absolute_value = true);
+
+/// Creates the eigenvalue matrix D that is part of the decomposition used by
+/// Matrix::Eig. D will be block-diagonal with blocks of size 1 (for real
+/// eigenvalues) or 2x2 for complex pairs. If a complex pair is
+/// lambda +- i*mu, D will have a corresponding 2x2 block
+/// [lambda, mu; -mu, lambda].
+/// This function will throw if any complex eigenvalues are not in complex
+/// conjugate pairs (or the members of such pairs are not consecutively
+/// numbered).
+template <typename Real>
+void CreateEigenvalueMatrix(const VectorBase<Real> &real,
+                            const VectorBase<Real> &imag,
+                            MatrixBase<Real> *D);
+
+/// The following function is used in Matrix::Power, and separately tested,
+/// so we declare it here mainly for the testing code to see. It takes a
+/// complex value to a power using a method that will work for noninteger
+/// powers (but will fail if the complex value is real and negative).
+template <typename Real>
+bool AttemptComplexPower(Real *x_re, Real *x_im, Real power);
+
+**********/
+
+/// @} end of addtogroup matrix_funcs_misc
+
+/// \addtogroup matrix_funcs_io
+/// @{
+template <typename Real>
+std::ostream &operator<<(std::ostream &Out, const MatrixBase<Real> &M);
+
+template <typename Real>
+std::istream &operator>>(std::istream &In, MatrixBase<Real> &M);
+
+// The Matrix read allows resizing, so we override the MatrixBase one.
+template <typename Real>
+std::istream &operator>>(std::istream &In, Matrix<Real> &M);
+
+template <typename Real>
+bool SameDim(const MatrixBase<Real> &M, const MatrixBase<Real> &N) {
+    return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols());
+}
+
+/// @} end of \addtogroup matrix_funcs_io
+
+
+} // namespace kaldi
+
+
+// we need to include the implementation and some
+// template specializations.
+#include "matrix/kaldi-matrix-inl.h" + + +#endif // KALDI_MATRIX_KALDI_MATRIX_H_ diff --git a/speechx/speechx/kaldi/matrix/kaldi-vector-inl.h b/runtime/engine/common/matrix/kaldi-vector-inl.h similarity index 63% rename from speechx/speechx/kaldi/matrix/kaldi-vector-inl.h rename to runtime/engine/common/matrix/kaldi-vector-inl.h index c3a4f52f..b3075e59 100644 --- a/speechx/speechx/kaldi/matrix/kaldi-vector-inl.h +++ b/runtime/engine/common/matrix/kaldi-vector-inl.h @@ -26,32 +26,33 @@ namespace kaldi { -template -std::ostream & operator << (std::ostream &os, const VectorBase &rv) { - rv.Write(os, false); - return os; +template +std::ostream &operator<<(std::ostream &os, const VectorBase &rv) { + rv.Write(os, false); + return os; } -template -std::istream &operator >> (std::istream &is, VectorBase &rv) { - rv.Read(is, false); - return is; +template +std::istream &operator>>(std::istream &is, VectorBase &rv) { + rv.Read(is, false); + return is; } -template -std::istream &operator >> (std::istream &is, Vector &rv) { - rv.Read(is, false); - return is; +template +std::istream &operator>>(std::istream &is, Vector &rv) { + rv.Read(is, false); + return is; } -template<> -template<> -void VectorBase::AddVec(const float alpha, const VectorBase &rv); +// template<> +// template<> +// void VectorBase::AddVec(const float alpha, const VectorBase +// &rv); -template<> -template<> -void VectorBase::AddVec(const double alpha, - const VectorBase &rv); +// template<> +// template<> +// void VectorBase::AddVec(const double alpha, +// const VectorBase &rv); } // namespace kaldi diff --git a/runtime/engine/common/matrix/kaldi-vector.cc b/runtime/engine/common/matrix/kaldi-vector.cc new file mode 100644 index 00000000..3ab9a7ff --- /dev/null +++ b/runtime/engine/common/matrix/kaldi-vector.cc @@ -0,0 +1,1239 @@ +// matrix/kaldi-vector.cc + +// Copyright 2009-2011 Microsoft Corporation; Lukas Burget; +// Saarland University; Go Vivace Inc.; Ariya Rastrow; +// Petr Schwarz; Yanmin Qian; Jan Silovsky; +// Haihua Xu; Wei Shi +// 2015 Guoguo Chen +// 2017 Daniel Galvez +// 2019 Yiwen Shao + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "matrix/kaldi-vector.h" + +#include +#include + +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +template +inline void Vector::Init(const MatrixIndexT dim) { + KALDI_ASSERT(dim >= 0); + if (dim == 0) { + this->dim_ = 0; + this->data_ = NULL; + return; + } + MatrixIndexT size; + void *data; + void *free_data; + + size = dim * sizeof(Real); + + if ((data = KALDI_MEMALIGN(16, size, &free_data)) != NULL) { + this->data_ = static_cast(data); + this->dim_ = dim; + } else { + throw std::bad_alloc(); + } +} + + +template +void Vector::Resize(const MatrixIndexT dim, + MatrixResizeType resize_type) { + // the next block uses recursion to handle what we have to do if + // resize_type == kCopyData. 
+    if (resize_type == kCopyData) {
+        if (this->data_ == NULL || dim == 0)
+            resize_type = kSetZero;  // nothing to copy.
+        else if (this->dim_ == dim) {
+            return;  // nothing to do.
+        } else {
+            // set tmp to a vector of the desired size.
+            Vector<Real> tmp(dim, kUndefined);
+            if (dim > this->dim_) {
+                memcpy(tmp.data_, this->data_, sizeof(Real) * this->dim_);
+                memset(tmp.data_ + this->dim_,
+                       0,
+                       sizeof(Real) * (dim - this->dim_));
+            } else {
+                memcpy(tmp.data_, this->data_, sizeof(Real) * dim);
+            }
+            tmp.Swap(this);
+            // and now let tmp go out of scope, deleting what was in *this.
+            return;
+        }
+    }
+    // At this point, resize_type == kSetZero or kUndefined.
+
+    if (this->data_ != NULL) {
+        if (this->dim_ == dim) {
+            if (resize_type == kSetZero) this->SetZero();
+            return;
+        } else {
+            Destroy();
+        }
+    }
+    Init(dim);
+    if (resize_type == kSetZero) this->SetZero();
+}
+
+
+/// Copy data from another vector
+template <typename Real>
+void VectorBase<Real>::CopyFromVec(const VectorBase<Real> &v) {
+    KALDI_ASSERT(Dim() == v.Dim());
+    if (data_ != v.data_) {
+        std::memcpy(this->data_, v.data_, dim_ * sizeof(Real));
+    }
+}
+
+/*
+template <typename Real>
+template <typename OtherReal>
+void VectorBase<Real>::CopyFromPacked(const PackedMatrix<OtherReal> &M) {
+    SubVector<OtherReal> v(M);
+    this->CopyFromVec(v);
+}
+// instantiate the template.
+template void VectorBase<float>::CopyFromPacked(const PackedMatrix<double>
+&other);
+template void VectorBase<float>::CopyFromPacked(const PackedMatrix<float>
+&other);
+template void VectorBase<double>::CopyFromPacked(const PackedMatrix<double>
+&other);
+template void VectorBase<double>::CopyFromPacked(const PackedMatrix<float>
+&other);
+
+/// Load data into the vector
+template <typename Real>
+void VectorBase<Real>::CopyFromPtr(const Real *data, MatrixIndexT sz) {
+    KALDI_ASSERT(dim_ == sz);
+    std::memcpy(this->data_, data, Dim() * sizeof(Real));
+}*/
+
+template <typename Real>
+template <typename OtherReal>
+void VectorBase<Real>::CopyFromVec(const VectorBase<OtherReal> &other) {
+    KALDI_ASSERT(dim_ == other.Dim());
+    Real *__restrict__ ptr = data_;
+    const OtherReal *__restrict__ other_ptr = other.Data();
+    for (MatrixIndexT i = 0; i < dim_; i++) ptr[i] = other_ptr[i];
+}
+
+template void VectorBase<float>::CopyFromVec(const VectorBase<double> &other);
+template void VectorBase<double>::CopyFromVec(const VectorBase<float> &other);
+
+// Remove element from the vector. The vector is not reallocated.
+template <typename Real>
+void Vector<Real>::RemoveElement(MatrixIndexT i) {
+    KALDI_ASSERT(i < this->dim_ && "Access out of vector");
+    for (MatrixIndexT j = i + 1; j < this->dim_; j++)
+        this->data_[j - 1] = this->data_[j];
+    this->dim_--;
+}
+
+
+/// Deallocates memory and sets object to empty vector.
+template <typename Real>
+void Vector<Real>::Destroy() {
+    // we need to free the data block if it was defined
+    if (this->data_ != NULL) KALDI_MEMALIGN_FREE(this->data_);
+    this->data_ = NULL;
+    this->dim_ = 0;
+}
+
+template <typename Real>
+void VectorBase<Real>::SetZero() {
+    std::memset(data_, 0, dim_ * sizeof(Real));
+}
+
+template <typename Real>
+bool VectorBase<Real>::IsZero(Real cutoff) const {
+    Real abs_max = 0.0;
+    for (MatrixIndexT i = 0; i < Dim(); i++)
+        abs_max = std::max(std::abs(data_[i]), abs_max);
+    return (abs_max <= cutoff);
+}
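+
+// Editor's sketch (not part of the original sources): the kCopyData branch
+// of Resize() above preserves the overlapping prefix and zero-fills any
+// growth, e.g.:
+//   Vector<float> v(2);        // [0 0]
+//   v(0) = 1.0f; v(1) = 2.0f;  // [1 2]
+//   v.Resize(4, kCopyData);    // [1 2 0 0]
+//   v.Resize(1, kCopyData);    // [1]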
+/*
+template <typename Real>
+void VectorBase<Real>::SetRandn() {
+    kaldi::RandomState rstate;
+    MatrixIndexT last = (Dim() % 2 == 1) ? Dim() - 1 : Dim();
+    for (MatrixIndexT i = 0; i < last; i += 2) {
+        kaldi::RandGauss2(data_ + i, data_ + i + 1, &rstate);
+    }
+    if (Dim() != last)
+        data_[last] = static_cast<Real>(kaldi::RandGauss(&rstate));
+}
+
+template <typename Real>
+void VectorBase<Real>::SetRandUniform() {
+    kaldi::RandomState rstate;
+    for (MatrixIndexT i = 0; i < Dim(); i++) {
+        *(data_ + i) = RandUniform(&rstate);
+    }
+}
+
+template <typename Real>
+MatrixIndexT VectorBase<Real>::RandCategorical() const {
+    kaldi::RandomState rstate;
+    Real sum = this->Sum();
+    KALDI_ASSERT(this->Min() >= 0.0 && sum > 0.0);
+    Real r = RandUniform(&rstate) * sum;
+    Real *data = this->data_;
+    MatrixIndexT dim = this->dim_;
+    Real running_sum = 0.0;
+    for (MatrixIndexT i = 0; i < dim; i++) {
+        running_sum += data[i];
+        if (r < running_sum) return i;
+    }
+    return dim_ - 1;  // Should only happen if RandUniform()
+                      // returns exactly 1, or due to roundoff.
+}*/
+
+template <typename Real>
+void VectorBase<Real>::Set(Real f) {
+    // Why not use memset here?
+    // The basic unit of memset is a byte.
+    // If f != 0 and sizeof(Real) > 1, then we cannot use memset.
+    if (f == 0) {
+        this->SetZero();  // calls std::memset
+    } else {
+        for (MatrixIndexT i = 0; i < dim_; i++) {
+            data_[i] = f;
+        }
+    }
+}
+
+template <typename Real>
+void VectorBase<Real>::CopyRowsFromMat(const MatrixBase<Real> &mat) {
+    KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows());
+
+    Real *inc_data = data_;
+    const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows();
+
+    if (mat.Stride() == mat.NumCols()) {
+        memcpy(inc_data, mat.Data(), cols * rows * sizeof(Real));
+    } else {
+        for (MatrixIndexT i = 0; i < rows; i++) {
+            // copy the data to the proper position
+            memcpy(inc_data, mat.RowData(i), cols * sizeof(Real));
+            // set new copy position
+            inc_data += cols;
+        }
+    }
+}
+
+template <typename Real>
+template <typename OtherReal>
+void VectorBase<Real>::CopyRowsFromMat(const MatrixBase<OtherReal> &mat) {
+    KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows());
+    Real *vec_data = data_;
+    const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows();
+
+    for (MatrixIndexT i = 0; i < rows; i++) {
+        const OtherReal *mat_row = mat.RowData(i);
+        for (MatrixIndexT j = 0; j < cols; j++) {
+            vec_data[j] = static_cast<Real>(mat_row[j]);
+        }
+        vec_data += cols;
+    }
+}
+
+template void VectorBase<float>::CopyRowsFromMat(const MatrixBase<double> &mat);
+template void VectorBase<double>::CopyRowsFromMat(const MatrixBase<float> &mat);
+
+
+template <typename Real>
+void VectorBase<Real>::CopyColsFromMat(const MatrixBase<Real> &mat) {
+    KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows());
+
+    Real *inc_data = data_;
+    const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(),
+                       stride = mat.Stride();
+    const Real *mat_inc_data = mat.Data();
+
+    for (MatrixIndexT i = 0; i < cols; i++) {
+        for (MatrixIndexT j = 0; j < rows; j++) {
+            inc_data[j] = mat_inc_data[j * stride];
+        }
+        mat_inc_data++;
+        inc_data += rows;
+    }
+}
+
+template <typename Real>
+void VectorBase<Real>::CopyRowFromMat(const MatrixBase<Real> &mat,
+                                      MatrixIndexT row) {
+    KALDI_ASSERT(row < mat.NumRows());
+    KALDI_ASSERT(dim_ == mat.NumCols());
+    const Real *mat_row = mat.RowData(row);
+    memcpy(data_, mat_row, sizeof(Real) * dim_);
+}
+
+template <typename Real>
+template <typename OtherReal>
+void VectorBase<Real>::CopyRowFromMat(const MatrixBase<OtherReal> &mat,
+                                      MatrixIndexT row) {
+    KALDI_ASSERT(row < mat.NumRows());
+    KALDI_ASSERT(dim_ == mat.NumCols());
+    const OtherReal *mat_row = mat.RowData(row);
+    for (MatrixIndexT i = 0; i < dim_; i++)
+        data_[i] = static_cast<Real>(mat_row[i]);
+}
+
+template void VectorBase<float>::CopyRowFromMat(const MatrixBase<double> &mat,
+                                                MatrixIndexT row);
+template void VectorBase<double>::CopyRowFromMat(const MatrixBase<float> &mat,
+                                                 MatrixIndexT row);
+
+/*
+template <typename Real>
+template <typename OtherReal>
+void 
VectorBase::CopyRowFromSp(const SpMatrix &sp, MatrixIndexT +row) { + KALDI_ASSERT(row < sp.NumRows()); + KALDI_ASSERT(dim_ == sp.NumCols()); + + const OtherReal *sp_data = sp.Data(); + + sp_data += (row*(row+1)) / 2; // takes us to beginning of this row. + MatrixIndexT i; + for (i = 0; i < row; i++) // copy consecutive elements. + data_[i] = static_cast(*(sp_data++)); + for(; i < dim_; ++i, sp_data += i) + data_[i] = static_cast(*sp_data); +} + +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT +row); + +// takes absolute value of the elements to a power. +// Throws exception if could not (but only for power != 1 and power != 2). +template +void VectorBase::ApplyPowAbs(Real power, bool include_sign) { + if (power == 1.0) + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * std::abs(data_[i]); + if (power == 2.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * data_[i] * data_[i]; + } else if (power == 0.5) { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * +std::sqrt(std::abs(data_[i])); + } + } else if (power < 0.0) { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = (data_[i] == 0.0 ? 0.0 : pow(std::abs(data_[i]), power)); + data_[i] *= (include_sign && data_[i] < 0 ? -1 : 1); + if (data_[i] == HUGE_VAL) { // HUGE_VAL is what errno returns on error. + KALDI_ERR << "Could not raise element " << i << "to power " + << power << ": returned value = " << data_[i]; + } + } + } else { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * +pow(std::abs(data_[i]), power); + if (data_[i] == HUGE_VAL) { // HUGE_VAL is what errno returns on error. + KALDI_ERR << "Could not raise element " << i << "to power " + << power << ": returned value = " << data_[i]; + } + } + } +} + +// Computes the p-th norm. Throws exception if could not. +template +Real VectorBase::Norm(Real p) const { + KALDI_ASSERT(p >= 0.0); + Real sum = 0.0; + if (p == 0.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + if (data_[i] != 0.0) sum += 1.0; + return sum; + } else if (p == 1.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + sum += std::abs(data_[i]); + return sum; + } else if (p == 2.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + sum += data_[i] * data_[i]; + return std::sqrt(sum); + } else if (p == std::numeric_limits::infinity()){ + for (MatrixIndexT i = 0; i < dim_; i++) + sum = std::max(sum, std::abs(data_[i])); + return sum; + } else { + Real tmp; + bool ok = true; + for (MatrixIndexT i = 0; i < dim_; i++) { + tmp = pow(std::abs(data_[i]), p); + if (tmp == HUGE_VAL) // HUGE_VAL is what pow returns on error. + ok = false; + sum += tmp; + } + tmp = pow(sum, static_cast(1.0/p)); + KALDI_ASSERT(tmp != HUGE_VAL); // should not happen here. + if (ok) { + return tmp; + } else { + Real maximum = this->Max(), minimum = this->Min(), + max_abs = std::max(maximum, -minimum); + KALDI_ASSERT(max_abs > 0); // Or should not have reached here. 
+ Vector tmp(*this); + tmp.Scale(1.0 / max_abs); + return tmp.Norm(p) * max_abs; + } + } +} + +template +bool VectorBase::ApproxEqual(const VectorBase &other, float tol) +const { + if (dim_ != other.dim_) KALDI_ERR << "ApproxEqual: size mismatch " + << dim_ << " vs. " << other.dim_; + KALDI_ASSERT(tol >= 0.0); + if (tol != 0.0) { + Vector tmp(*this); + tmp.AddVec(-1.0, other); + return (tmp.Norm(2.0) <= static_cast(tol) * this->Norm(2.0)); + } else { // Test for exact equality. + const Real *data = data_; + const Real *other_data = other.data_; + for (MatrixIndexT dim = dim_, i = 0; i < dim; i++) + if (data[i] != other_data[i]) return false; + return true; + } +} + +template +Real VectorBase::Max() const { + Real ans = - std::numeric_limits::infinity(); + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 > ans || a2 > ans || a3 > ans || a4 > ans) { + Real b1 = (a1 > a2 ? a1 : a2), b2 = (a3 > a4 ? a3 : a4); + if (b1 > ans) ans = b1; + if (b2 > ans) ans = b2; + } + } + for (; i < dim; i++) + if (data[i] > ans) ans = data[i]; + return ans; +} + +template +Real VectorBase::Max(MatrixIndexT *index_out) const { + if (dim_ == 0) KALDI_ERR << "Empty vector"; + Real ans = - std::numeric_limits::infinity(); + MatrixIndexT index = 0; + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 > ans || a2 > ans || a3 > ans || a4 > ans) { + if (a1 > ans) { ans = a1; index = i; } + if (a2 > ans) { ans = a2; index = i + 1; } + if (a3 > ans) { ans = a3; index = i + 2; } + if (a4 > ans) { ans = a4; index = i + 3; } + } + } + for (; i < dim; i++) + if (data[i] > ans) { ans = data[i]; index = i; } + *index_out = index; + return ans; +} + +template +Real VectorBase::Min() const { + Real ans = std::numeric_limits::infinity(); + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 < ans || a2 < ans || a3 < ans || a4 < ans) { + Real b1 = (a1 < a2 ? a1 : a2), b2 = (a3 < a4 ? a3 : a4); + if (b1 < ans) ans = b1; + if (b2 < ans) ans = b2; + } + } + for (; i < dim; i++) + if (data[i] < ans) ans = data[i]; + return ans; +} + +template +Real VectorBase::Min(MatrixIndexT *index_out) const { + if (dim_ == 0) KALDI_ERR << "Empty vector"; + Real ans = std::numeric_limits::infinity(); + MatrixIndexT index = 0; + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 < ans || a2 < ans || a3 < ans || a4 < ans) { + if (a1 < ans) { ans = a1; index = i; } + if (a2 < ans) { ans = a2; index = i + 1; } + if (a3 < ans) { ans = a3; index = i + 2; } + if (a4 < ans) { ans = a4; index = i + 3; } + } + } + for (; i < dim; i++) + if (data[i] < ans) { ans = data[i]; index = i; } + *index_out = index; + return ans; +}*/ + + +template +template +void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col) { + KALDI_ASSERT(col < mat.NumCols()); + KALDI_ASSERT(dim_ == mat.NumRows()); + for (MatrixIndexT i = 0; i < dim_; i++) data_[i] = mat(i, col); + // can't do this very efficiently so don't really bother. could improve this + // though. +} +// instantiate the template above. 
+template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); +template void VectorBase::CopyColFromMat(const MatrixBase &mat, + MatrixIndexT col); + +// template +// void VectorBase::CopyDiagFromMat(const MatrixBase &M) { +// KALDI_ASSERT(dim_ == std::min(M.NumRows(), M.NumCols())); +// cblas_Xcopy(dim_, M.Data(), M.Stride() + 1, data_, 1); +//} + +// template +// void VectorBase::CopyDiagFromPacked(const PackedMatrix &M) { +// KALDI_ASSERT(dim_ == M.NumCols()); +// for (MatrixIndexT i = 0; i < dim_; i++) +// data_[i] = M(i, i); +//// could make this more efficient. +//} + +// template +// Real VectorBase::Sum() const { +//// Do a dot-product with a size-1 array with a stride of 0 to +//// implement sum. This allows us to access SIMD operations in a +//// cross-platform way via your BLAS library. +// Real one(1); +// return cblas_Xdot(dim_, data_, 1, &one, 0); +//} + +// template +// Real VectorBase::SumLog() const { +// double sum_log = 0.0; +// double prod = 1.0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// prod *= data_[i]; +//// Possible future work (arnab): change these magic values to pre-defined +//// constants +// if (prod < 1.0e-10 || prod > 1.0e+10) { +// sum_log += Log(prod); +// prod = 1.0; +//} +//} +// if (prod != 1.0) sum_log += Log(prod); +// return sum_log; +//} + +// template +// void VectorBase::AddRowSumMat(Real alpha, const MatrixBase &M, +// Real beta) { +// KALDI_ASSERT(dim_ == M.NumCols()); +// MatrixIndexT num_rows = M.NumRows(), stride = M.Stride(), dim = dim_; +// Real *data = data_; + +//// implement the function according to a dimension cutoff for computation +/// efficiency +// if (num_rows <= 64) { +// cblas_Xscal(dim, beta, data, 1); +// const Real *m_data = M.Data(); +// for (MatrixIndexT i = 0; i < num_rows; i++, m_data += stride) +// cblas_Xaxpy(dim, alpha, m_data, 1, data, 1); + +//} else { +// Vector ones(M.NumRows()); +// ones.Set(1.0); +// this->AddMatVec(alpha, M, kTrans, ones, beta); +//} +//} + +// template +// void VectorBase::AddColSumMat(Real alpha, const MatrixBase &M, +// Real beta) { +// KALDI_ASSERT(dim_ == M.NumRows()); +// MatrixIndexT num_cols = M.NumCols(); + +//// implement the function according to a dimension cutoff for computation +/// efficiency +// if (num_cols <= 64) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// double sum = 0.0; +// const Real *src = M.RowData(i); +// for (MatrixIndexT j = 0; j < num_cols; j++) +// sum += src[j]; +// data_[i] = alpha * sum + beta * data_[i]; +//} +//} else { +// Vector ones(M.NumCols()); +// ones.Set(1.0); +// this->AddMatVec(alpha, M, kNoTrans, ones, beta); +//} +//} + +// template +// Real VectorBase::LogSumExp(Real prune) const { +// Real sum; +// if (sizeof(sum) == 8) sum = kLogZeroDouble; +// else sum = kLogZeroFloat; +// Real max_elem = Max(), cutoff; +// if (sizeof(Real) == 4) cutoff = max_elem + kMinLogDiffFloat; +// else cutoff = max_elem + kMinLogDiffDouble; +// if (prune > 0.0 && max_elem - prune > cutoff) // explicit pruning... 
+// cutoff = max_elem - prune; + +// double sum_relto_max_elem = 0.0; + +// for (MatrixIndexT i = 0; i < dim_; i++) { +// BaseFloat f = data_[i]; +// if (f >= cutoff) +// sum_relto_max_elem += Exp(f - max_elem); +//} +// return max_elem + Log(sum_relto_max_elem); +//} + +// template +// void VectorBase::InvertElements() { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = static_cast(1 / data_[i]); +//} +//} + +// template +// void VectorBase::ApplyLog() { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (data_[i] < 0.0) +// KALDI_ERR << "Trying to take log of a negative number."; +// data_[i] = Log(data_[i]); +//} +//} + +// template +// void VectorBase::ApplyLogAndCopy(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.Dim()); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = Log(v(i)); +//} +//} + +// template +// void VectorBase::ApplyExp() { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = Exp(data_[i]); +//} +//} + +// template +// void VectorBase::ApplyAbs() { +// for (MatrixIndexT i = 0; i < dim_; i++) { data_[i] = std::abs(data_[i]); } +//} + +// template +// void VectorBase::Floor(const VectorBase &v, Real floor_val, +// MatrixIndexT *floored_count) { +// KALDI_ASSERT(dim_ == v.dim_); +// if (floored_count == nullptr) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = std::max(v.data_[i], floor_val); +//} +//} else { +// MatrixIndexT num_floored = 0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (v.data_[i] < floor_val) { +// data_[i] = floor_val; +// num_floored++; +//} else { +// data_[i] = v.data_[i]; +//} +//} +//*floored_count = num_floored; +//} +//} + +// template +// void VectorBase::Ceiling(const VectorBase &v, Real ceil_val, +// MatrixIndexT *ceiled_count) { +// KALDI_ASSERT(dim_ == v.dim_); +// if (ceiled_count == nullptr) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = std::min(v.data_[i], ceil_val); +//} +//} else { +// MatrixIndexT num_changed = 0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (v.data_[i] > ceil_val) { +// data_[i] = ceil_val; +// num_changed++; +//} else { +// data_[i] = v.data_[i]; +//} +//} +//*ceiled_count = num_changed; +//} +//} + +// template +// MatrixIndexT VectorBase::ApplyFloor(const VectorBase &floor_vec) +// { +// KALDI_ASSERT(floor_vec.Dim() == dim_); +// MatrixIndexT num_floored = 0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// if (data_[i] < floor_vec(i)) { +// data_[i] = floor_vec(i); +// num_floored++; +//} +//} +// return num_floored; +//} + +// template +// Real VectorBase::ApplySoftMax() { +// Real max = this->Max(), sum = 0.0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// sum += (data_[i] = Exp(data_[i] - max)); +//} +// this->Scale(1.0 / sum); +// return max + Log(sum); +//} + +// template +// Real VectorBase::ApplyLogSoftMax() { +// Real max = this->Max(), sum = 0.0; +// for (MatrixIndexT i = 0; i < dim_; i++) { +// sum += Exp((data_[i] -= max)); +//} +// sum = Log(sum); +// this->Add(-1.0 * sum); +// return max + sum; +//} + +//#ifdef HAVE_MKL +// template<> +// void VectorBase::Tanh(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// vsTanh(dim_, src.data_, data_); +//} +// template<> +// void VectorBase::Tanh(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// vdTanh(dim_, src.data_, data_); +//} +//#else +// template +// void VectorBase::Tanh(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// Real x = src.data_[i]; +// if (x > 0.0) { +// Real inv_expx = 
Exp(-x); +// x = -1.0 + 2.0 / (1.0 + inv_expx * inv_expx); +//} else { +// Real expx = Exp(x); +// x = 1.0 - 2.0 / (1.0 + expx * expx); +//} +// data_[i] = x; +//} +//} +//#endif + +//#ifdef HAVE_MKL +//// Implementing sigmoid based on tanh. +// template<> +// void VectorBase::Sigmoid(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// this->CopyFromVec(src); +// this->Scale(0.5); +// vsTanh(dim_, data_, data_); +// this->Add(1.0); +// this->Scale(0.5); +//} +// template<> +// void VectorBase::Sigmoid(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// this->CopyFromVec(src); +// this->Scale(0.5); +// vdTanh(dim_, data_, data_); +// this->Add(1.0); +// this->Scale(0.5); +//} +//#else +// template +// void VectorBase::Sigmoid(const VectorBase &src) { +// KALDI_ASSERT(dim_ == src.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// Real x = src.data_[i]; +//// We aim to avoid floating-point overflow here. +// if (x > 0.0) { +// x = 1.0 / (1.0 + Exp(-x)); +//} else { +// Real ex = Exp(x); +// x = ex / (ex + 1.0); +//} +// data_[i] = x; +//} +//} +//#endif + + +// template +// void VectorBase::Add(Real c) { +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] += c; +//} +//} + +// template +// void VectorBase::Scale(Real alpha) { +// cblas_Xscal(dim_, alpha, data_, 1); +//} + +// template +// void VectorBase::MulElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] *= v.data_[i]; +//} +//} + +// template // Set each element to y = (x == orig ? changed : +// x). +// void VectorBase::ReplaceValue(Real orig, Real changed) { +// Real *data = data_; +// for (MatrixIndexT i = 0; i < dim_; i++) +// if (data[i] == orig) data[i] = changed; +//} + + +// template +// template +// void VectorBase::MulElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.Dim()); +// const OtherReal *other_ptr = v.Data(); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] *= other_ptr[i]; +//} +//} +//// instantiate template. +// template +// void VectorBase::MulElements(const VectorBase &v); +// template +// void VectorBase::MulElements(const VectorBase &v); + + +// template +// void VectorBase::AddVecVec(Real alpha, const VectorBase &v, +// const VectorBase &r, Real beta) { +// KALDI_ASSERT(v.data_ != this->data_ && r.data_ != this->data_); +//// We pretend that v is a band-diagonal matrix. +// KALDI_ASSERT(dim_ == v.dim_ && dim_ == r.dim_); +// cblas_Xgbmv(kNoTrans, dim_, dim_, 0, 0, alpha, v.data_, 1, +// r.data_, 1, beta, this->data_, 1); +//} + + +// template +// void VectorBase::DivElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] /= v.data_[i]; +//} +//} + +// template +// template +// void VectorBase::DivElements(const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.Dim()); +// const OtherReal *other_ptr = v.Data(); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] /= other_ptr[i]; +//} +//} +//// instantiate template. 
+// template +// void VectorBase::DivElements(const VectorBase &v); +// template +// void VectorBase::DivElements(const VectorBase &v); + +// template +// void VectorBase::AddVecDivVec(Real alpha, const VectorBase &v, +// const VectorBase &rr, Real beta) { +// KALDI_ASSERT((dim_ == v.dim_ && dim_ == rr.dim_)); +// for (MatrixIndexT i = 0; i < dim_; i++) { +// data_[i] = alpha * v.data_[i]/rr.data_[i] + beta * data_[i] ; +//} +//} + +// template +// template +// void VectorBase::AddVec(const Real alpha, const VectorBase +// &v) { +// KALDI_ASSERT(dim_ == v.dim_); +//// remove __restrict__ if it causes compilation problems. +// Real *__restrict__ data = data_; +// OtherReal *__restrict__ other_data = v.data_; +// MatrixIndexT dim = dim_; +// if (alpha != 1.0) +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += alpha * other_data[i]; +// else +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += other_data[i]; +//} + +// template +// void VectorBase::AddVec(const float alpha, const VectorBase +// &v); +// template +// void VectorBase::AddVec(const double alpha, const VectorBase +// &v); + +// template +// template +// void VectorBase::AddVec2(const Real alpha, const VectorBase +// &v) { +// KALDI_ASSERT(dim_ == v.dim_); +//// remove __restrict__ if it causes compilation problems. +// Real *__restrict__ data = data_; +// OtherReal *__restrict__ other_data = v.data_; +// MatrixIndexT dim = dim_; +// if (alpha != 1.0) +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += alpha * other_data[i] * other_data[i]; +// else +// for (MatrixIndexT i = 0; i < dim; i++) +// data[i] += other_data[i] * other_data[i]; +//} + +// template +// void VectorBase::AddVec2(const float alpha, const VectorBase +// &v); +// template +// void VectorBase::AddVec2(const double alpha, const VectorBase +// &v); + + +template +void VectorBase::Read(std::istream &is, bool binary) { + // In order to avoid rewriting this, we just declare a Vector and + // use it to read the data, then copy. + Vector tmp; + tmp.Read(is, binary); + if (tmp.Dim() != Dim()) + KALDI_ERR << "VectorBase::Read, size mismatch " << Dim() + << " vs. " << tmp.Dim(); + CopyFromVec(tmp); +} + + +template +void Vector::Read(std::istream &is, bool binary) { + std::ostringstream specific_error; + MatrixIndexT pos_at_start = is.tellg(); + + if (binary) { + int peekval = Peek(is, binary); + const char *my_token = (sizeof(Real) == 4 ? "FV" : "DV"); + char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); + if (peekval == other_token_start) { // need to instantiate the other + // type to read it. + typedef typename OtherReal::Real OtherType; // if Real == + // float, + // OtherType == + // double, and + // vice versa. + Vector other(this->Dim()); + other.Read(is, binary); // add is false at this point. + if (this->Dim() != other.Dim()) this->Resize(other.Dim()); + this->CopyFromVec(other); + return; + } + std::string token; + ReadToken(is, binary, &token); + if (token != my_token) { + if (token.length() > 20) token = token.substr(0, 17) + "..."; + specific_error << ": Expected token " << my_token << ", got " + << token; + goto bad; + } + int32 size; + ReadBasicType(is, binary, &size); // throws on error. + if ((MatrixIndexT)size != this->Dim()) this->Resize(size); + if (size > 0) + is.read(reinterpret_cast(this->data_), sizeof(Real) * size); + if (is.fail()) { + specific_error + << "Error reading vector data (binary mode); truncated " + "stream? 
(size = " + << size << ")"; + goto bad; + } + return; + } else { // Text mode reading; format is " [ 1.1 2.0 3.4 ]\n" + std::string s; + is >> s; + // if ((s.compare("DV") == 0) || (s.compare("FV") == 0)) { // Back + // compatibility. + // is >> s; // get dimension + // is >> s; // get "[" + // } + if (is.fail()) { + specific_error << "EOF while trying to read vector."; + goto bad; + } + if (s.compare("[]") == 0) { + Resize(0); + return; + } // tolerate this variant. + if (s.compare("[")) { + if (s.length() > 20) s = s.substr(0, 17) + "..."; + specific_error << "Expected \"[\" but got " << s; + goto bad; + } + std::vector data; + while (1) { + int i = is.peek(); + if (i == '-' || (i >= '0' && i <= '9')) { // common cases first. + Real r; + is >> r; + if (is.fail()) { + specific_error << "Failed to read number."; + goto bad; + } + if (!std::isspace(is.peek()) && is.peek() != ']') { + specific_error << "Expected whitespace after number."; + goto bad; + } + data.push_back(r); + // But don't eat whitespace... we want to check that it's not + // newlines + // which would be valid only for a matrix. + } else if (i == ' ' || i == '\t') { + is.get(); + } else if (i == ']') { + is.get(); // eat the ']' + this->Resize(data.size()); + for (size_t j = 0; j < data.size(); j++) + this->data_[j] = data[j]; + i = is.peek(); + if (static_cast(i) == '\r') { + is.get(); + is.get(); // get \r\n (must eat what we wrote) + } else if (static_cast(i) == '\n') { + is.get(); + } // get \n (must eat what we wrote) + if (is.fail()) { + KALDI_WARN << "After end of vector data, read error."; + // we got the data we needed, so just warn for this error. + } + return; // success. + } else if (i == -1) { + specific_error << "EOF while reading vector data."; + goto bad; + } else if (i == '\n' || i == '\r') { + specific_error << "Newline found while reading vector (maybe " + "it's a matrix?)"; + goto bad; + } else { + is >> s; // read string. + if (!KALDI_STRCASECMP(s.c_str(), "inf") || + !KALDI_STRCASECMP(s.c_str(), "infinity")) { + data.push_back(std::numeric_limits::infinity()); + KALDI_WARN << "Reading infinite value into vector."; + } else if (!KALDI_STRCASECMP(s.c_str(), "nan")) { + data.push_back(std::numeric_limits::quiet_NaN()); + KALDI_WARN << "Reading NaN value into vector."; + } else { + if (s.length() > 20) s = s.substr(0, 17) + "..."; + specific_error << "Expecting numeric vector data, got " + << s; + goto bad; + } + } + } + } +// we never reach this line (the while loop returns directly). +bad: + KALDI_ERR << "Failed to read vector from stream. " << specific_error.str() + << " File position at start is " << pos_at_start << ", currently " + << is.tellg(); +} + + +template +void VectorBase::Write(std::ostream &os, bool binary) const { + if (!os.good()) { + KALDI_ERR << "Failed to write vector to stream: stream not good"; + } + if (binary) { + std::string my_token = (sizeof(Real) == 4 ? "FV" : "DV"); + WriteToken(os, binary, my_token); + + int32 size = Dim(); // make the size 32-bit on disk. 
+ KALDI_ASSERT(Dim() == (MatrixIndexT)size); + WriteBasicType(os, binary, size); + os.write(reinterpret_cast(Data()), sizeof(Real) * size); + } else { + os << " [ "; + for (MatrixIndexT i = 0; i < Dim(); i++) os << (*this)(i) << " "; + os << "]\n"; + } + if (!os.good()) KALDI_ERR << "Failed to write vector to stream"; +} + + +// template +// void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { +// KALDI_ASSERT(dim_ == v.dim_); +// for (MatrixIndexT i = 0; i < dim_; i++) +// data_[i] += alpha * v.data_[i] * v.data_[i]; +//} + +//// this <-- beta*this + alpha*M*v. +// template +// void VectorBase::AddTpVec(const Real alpha, const TpMatrix &M, +// const MatrixTransposeType trans, +// const VectorBase &v, +// const Real beta) { +// KALDI_ASSERT(dim_ == v.dim_ && dim_ == M.NumRows()); +// if (beta == 0.0) { +// if (&v != this) CopyFromVec(v); +// MulTp(M, trans); +// if (alpha != 1.0) Scale(alpha); +//} else { +// Vector tmp(v); +// tmp.MulTp(M, trans); +// if (beta != 1.0) Scale(beta); // *this <-- beta * *this +// AddVec(alpha, tmp); // *this += alpha * M * v +//} +//} + +// template +// Real VecMatVec(const VectorBase &v1, const MatrixBase &M, +// const VectorBase &v2) { +// KALDI_ASSERT(v1.Dim() == M.NumRows() && v2.Dim() == M.NumCols()); +// Vector vtmp(M.NumRows()); +// vtmp.AddMatVec(1.0, M, kNoTrans, v2, 0.0); +// return VecVec(v1, vtmp); +//} + +// template +// float VecMatVec(const VectorBase &v1, const MatrixBase &M, +// const VectorBase &v2); +// template +// double VecMatVec(const VectorBase &v1, const MatrixBase &M, +// const VectorBase &v2); + +template +void Vector::Swap(Vector *other) { + std::swap(this->data_, other->data_); + std::swap(this->dim_, other->dim_); +} + + +// template +// void VectorBase::AddDiagMat2( +// Real alpha, const MatrixBase &M, +// MatrixTransposeType trans, Real beta) { +// if (trans == kNoTrans) { +// KALDI_ASSERT(this->dim_ == M.NumRows()); +// MatrixIndexT rows = this->dim_, cols = M.NumCols(), +// mat_stride = M.Stride(); +// Real *data = this->data_; +// const Real *mat_data = M.Data(); +// for (MatrixIndexT i = 0; i < rows; i++, mat_data += mat_stride, data++) +//*data = beta * *data + alpha * cblas_Xdot(cols,mat_data,1,mat_data,1); +//} else { +// KALDI_ASSERT(this->dim_ == M.NumCols()); +// MatrixIndexT rows = M.NumRows(), cols = this->dim_, +// mat_stride = M.Stride(); +// Real *data = this->data_; +// const Real *mat_data = M.Data(); +// for (MatrixIndexT i = 0; i < cols; i++, mat_data++, data++) +//*data = beta * *data + alpha * cblas_Xdot(rows, mat_data, mat_stride, +// mat_data, mat_stride); +//} +//} + +// template +// void VectorBase::AddDiagMatMat( +// Real alpha, +// const MatrixBase &M, MatrixTransposeType transM, +// const MatrixBase &N, MatrixTransposeType transN, +// Real beta) { +// MatrixIndexT dim = this->dim_, +// M_col_dim = (transM == kTrans ? M.NumRows() : M.NumCols()), +// N_row_dim = (transN == kTrans ? 
N.NumCols() : N.NumRows());
+// KALDI_ASSERT(M_col_dim == N_row_dim); // this is the dimension we sum over
+// MatrixIndexT M_row_stride = M.Stride(), M_col_stride = 1;
+// if (transM == kTrans) std::swap(M_row_stride, M_col_stride);
+// MatrixIndexT N_row_stride = N.Stride(), N_col_stride = 1;
+// if (transN == kTrans) std::swap(N_row_stride, N_col_stride);

+// Real *data = this->data_;
+// const Real *Mdata = M.Data(), *Ndata = N.Data();
+// for (MatrixIndexT i = 0; i < dim; i++, Mdata += M_row_stride, Ndata +=
+// N_col_stride, data++) {
+//*data = beta * *data + alpha * cblas_Xdot(M_col_dim, Mdata, M_col_stride,
+// Ndata, N_row_stride);
+//}
+//}


+template class Vector<float>;
+template class Vector<double>;
+template class VectorBase<float>;
+template class VectorBase<double>;

+} // namespace kaldi
diff --git a/runtime/engine/common/matrix/kaldi-vector.h b/runtime/engine/common/matrix/kaldi-vector.h
new file mode 100644
index 00000000..461e026d
--- /dev/null
+++ b/runtime/engine/common/matrix/kaldi-vector.h
@@ -0,0 +1,352 @@
+// matrix/kaldi-vector.h

+// Copyright 2009-2012  Ondrej Glembek; Microsoft Corporation; Lukas Burget;
+//                      Saarland University (Author: Arnab Ghoshal);
+//                      Ariya Rastrow; Petr Schwarz; Yanmin Qian;
+//                      Karel Vesely; Go Vivace Inc.; Arnab Ghoshal
+//                      Wei Shi;
+//                2015  Guoguo Chen
+//                2017  Daniel Galvez
+//                2019  Yiwen Shao

+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.

+#ifndef KALDI_MATRIX_KALDI_VECTOR_H_
+#define KALDI_MATRIX_KALDI_VECTOR_H_ 1

+#include "matrix/matrix-common.h"

+namespace kaldi {

+/// \addtogroup matrix_group
+/// @{

+/// Provides a vector abstraction class.
+/// This class provides a way to work with vectors in kaldi.
+/// It encapsulates basic operations and memory optimizations.
+template <typename Real>
+class VectorBase {
+  public:
+    /// Set vector to all zeros.
+    void SetZero();

+    /// Returns true if the vector is all zeros.
+    bool IsZero(Real cutoff = 1.0e-06) const;  // replace magic number

+    /// Set all members of a vector to a specified value.
+    void Set(Real f);

+    /// Returns the dimension of the vector.
+    inline MatrixIndexT Dim() const { return dim_; }

+    /// Returns the size in memory of the vector, in bytes.
+    inline MatrixIndexT SizeInBytes() const { return (dim_ * sizeof(Real)); }

+    /// Returns a pointer to the start of the vector's data.
+    inline Real *Data() { return data_; }

+    /// Returns a pointer to the start of the vector's data (const).
+    inline const Real *Data() const { return data_; }

+    /// Indexing operator (const).
+    inline Real operator()(MatrixIndexT i) const {
+        KALDI_PARANOID_ASSERT(static_cast<UnsignedMatrixIndexT>(i) <
+                              static_cast<UnsignedMatrixIndexT>(dim_));
+        return *(data_ + i);
+    }

+    /// Indexing operator (non-const).
+    inline Real &operator()(MatrixIndexT i) {
+        KALDI_PARANOID_ASSERT(static_cast<UnsignedMatrixIndexT>(i) <
+                              static_cast<UnsignedMatrixIndexT>(dim_));
+        return *(data_ + i);
+    }
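
+    // Editor's sketch (not part of the original sources): element and raw
+    // data access; Data() exposes the contiguous Dim()-element buffer.
+    //   Vector<float> v(3);
+    //   v(0) = 1.0f;                 // operator() indexing
+    //   const float *p = v.Data();   // p[0] == 1.0f, p[1] == p[2] == 0.0f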
+    /** @brief Returns a sub-vector of a vector (a range of elements).
+     *  @param o [in] Origin, 0 <= o < Dim()
+     *  @param l [in] Length, 0 <= l <= Dim() - o
+     *  @return A SubVector object that aliases the data of the Vector object.
+     *  See @c SubVector class for details */
+    SubVector<Real> Range(const MatrixIndexT o, const MatrixIndexT l) {
+        return SubVector<Real>(*this, o, l);
+    }

+    /** @brief Returns a const sub-vector of a vector (a range of elements).
+     *  @param o [in] Origin, 0 <= o < Dim()
+     *  @param l [in] Length, 0 <= l <= Dim() - o
+     *  @return A SubVector object that aliases the data of the Vector object.
+     *  See @c SubVector class for details */
+    const SubVector<Real> Range(const MatrixIndexT o,
+                                const MatrixIndexT l) const {
+        return SubVector<Real>(*this, o, l);
+    }

+    /// Copy data from another vector (must match own size).
+    void CopyFromVec(const VectorBase<Real> &v);

+    /// Copy data from another vector of different type (double vs. float).
+    template <typename OtherReal>
+    void CopyFromVec(const VectorBase<OtherReal> &v);

+    /// Performs a row stack of the matrix M.
+    void CopyRowsFromMat(const MatrixBase<Real> &M);
+    template <typename OtherReal>
+    void CopyRowsFromMat(const MatrixBase<OtherReal> &M);

+    /// Performs a column stack of the matrix M.
+    void CopyColsFromMat(const MatrixBase<Real> &M);

+    /// Extracts a row of the matrix M. Could also do this with
+    /// this->Copy(M[row]).
+    void CopyRowFromMat(const MatrixBase<Real> &M, MatrixIndexT row);
+    /// Extracts a row of the matrix M with type conversion.
+    template <typename OtherReal>
+    void CopyRowFromMat(const MatrixBase<OtherReal> &M, MatrixIndexT row);

+    /// Extracts a column of the matrix M.
+    template <typename OtherReal>
+    void CopyColFromMat(const MatrixBase<OtherReal> &M, MatrixIndexT col);

+    /// Reads from C++ stream (option to add to existing contents).
+    /// Throws exception on failure.
+    void Read(std::istream &in, bool binary);

+    /// Writes to C++ stream (option to write in binary).
+    void Write(std::ostream &Out, bool binary) const;

+    friend class VectorBase<double>;
+    friend class VectorBase<float>;

+  protected:
+    /// Destructor; does not deallocate memory, this is handled by child
+    /// classes.
+    /// This destructor is protected so this object can only be
+    /// deleted via a child.
+    ~VectorBase() {}

+    /// Empty initializer, corresponds to vector of zero size.
+    explicit VectorBase() : data_(NULL), dim_(0) {
+        KALDI_ASSERT_IS_FLOATING_TYPE(Real);
+    }

+    /// data memory area
+    Real *data_;
+    /// dimension of vector
+    MatrixIndexT dim_;
+    KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase);
+};  // class VectorBase

+/** @brief A class representing a vector.
+ *
+ *  This class provides a way to work with vectors in kaldi.
+ *  It encapsulates basic operations and memory optimizations. */
+template <typename Real>
+class Vector : public VectorBase<Real> {
+  public:
+    /// Constructor that takes no arguments. Initializes to empty.
+    Vector() : VectorBase<Real>() {}

+    /// Constructor with specific size. Sets to all-zero by default
+    /// (kSetZero); with resize_type == kUndefined, the memory contents
+    /// are undefined.
+    explicit Vector(const MatrixIndexT s,
+                    MatrixResizeType resize_type = kSetZero)
+        : VectorBase<Real>() {
+        Resize(s, resize_type);
+    }

+    /// Copy constructor from CUDA vector
+    /// This is defined in ../cudamatrix/cu-vector.h
+    // template <typename OtherReal>
+    // explicit Vector(const CuVectorBase<OtherReal> &cu);

+    /// Copy constructor. The need for this is controversial.
+    Vector(const Vector<Real> &v)
+        : VectorBase<Real>() {  // (cannot be explicit)
+        Resize(v.Dim(), kUndefined);
+        this->CopyFromVec(v);
+    }

+    /// Copy-constructor from base-class, needed to copy from SubVector.
+    explicit Vector(const VectorBase<Real> &v) : VectorBase<Real>() {
+        Resize(v.Dim(), kUndefined);
+        this->CopyFromVec(v);
+    }
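
+    // Editor's sketch (not part of the original sources): Range() above
+    // returns a SubVector view, so writes through it are visible in the
+    // parent, while the copy-from-base constructor makes a deep copy.
+    //   Vector<float> v(10);
+    //   SubVector<float> head = v.Range(0, 5);
+    //   head.Set(1.0f);             // first five elements of v are now 1.0
+    //   Vector<float> copy(head);   // independent deep copy of the view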
+    /// Type conversion constructor.
+    template <typename OtherReal>
+    explicit Vector(const VectorBase<OtherReal> &v) : VectorBase<Real>() {
+        Resize(v.Dim(), kUndefined);
+        this->CopyFromVec(v);
+    }

+    // Took this out since it is unsafe : Arnab
+    // /// Constructor from a pointer and a size; copies the data to a location
+    // /// it owns.
+    // Vector(const Real* Data, const MatrixIndexT s): VectorBase<Real>() {
+    //    Resize(s);
+    //    CopyFromPtr(Data, s);
+    // }


+    /// Swaps the contents of *this and *other. Shallow swap.
+    void Swap(Vector<Real> *other);

+    /// Destructor. Deallocates memory.
+    ~Vector() { Destroy(); }

+    /// Read function using C++ streams. Can also add to the existing
+    /// contents of the vector.
+    void Read(std::istream &in, bool binary);

+    /// Set vector to a specified size (can be zero).
+    /// The value of the new data depends on resize_type:
+    ///   -if kSetZero, the new data will be zero
+    ///   -if kUndefined, the new data will be undefined
+    ///   -if kCopyData, the new data will be the same as the old data in any
+    ///      shared positions, and zero elsewhere.
+    /// This function takes time proportional to the number of data elements.
+    void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero);

+    /// Removes one element and shifts later elements down.
+    void RemoveElement(MatrixIndexT i);

+    /// Assignment operator.
+    Vector<Real> &operator=(const Vector<Real> &other) {
+        Resize(other.Dim(), kUndefined);
+        this->CopyFromVec(other);
+        return *this;
+    }

+    /// Assignment operator that takes VectorBase.
+    Vector<Real> &operator=(const VectorBase<Real> &other) {
+        Resize(other.Dim(), kUndefined);
+        this->CopyFromVec(other);
+        return *this;
+    }

+  private:
+    /// Init assumes the current contents of the class are invalid (i.e. junk
+    /// or have already been freed), and it sets the vector to newly allocated
+    /// memory with the specified dimension. dim == 0 is acceptable. The
+    /// memory contents pointed to by data_ will be undefined.
+    void Init(const MatrixIndexT dim);

+    /// Destroy function, called internally.
+    void Destroy();
+};


+/// Represents a non-allocating general vector which can be defined
+/// as a sub-vector of higher-level vector [or as the row of a matrix].
+template <typename Real>
+class SubVector : public VectorBase<Real> {
+  public:
+    /// Constructor from a Vector or SubVector.
+    /// SubVectors are not const-safe and it's very hard to make them so;
+    /// for now we just give up. This function contains const_cast.
+    SubVector(const VectorBase<Real> &t,
+              const MatrixIndexT origin,
+              const MatrixIndexT length)
+        : VectorBase<Real>() {
+        // following assert equiv to origin >= 0 && length >= 0 &&
+        // origin + length <= t.Dim()
+        KALDI_ASSERT(static_cast<UnsignedMatrixIndexT>(origin) +
+                         static_cast<UnsignedMatrixIndexT>(length) <=
+                     static_cast<UnsignedMatrixIndexT>(t.Dim()));
+        VectorBase<Real>::data_ = const_cast<Real *>(t.Data() + origin);
+        VectorBase<Real>::dim_ = length;
+    }

+    /// This constructor initializes the vector to point at the contents
+    /// of this packed matrix (SpMatrix or TpMatrix).
+    // SubVector(const PackedMatrix<Real> &M) {
+    //    VectorBase<Real>::data_ = const_cast<Real *>(M.Data());
+    //    VectorBase<Real>::dim_ = (M.NumRows()*(M.NumRows()+1))/2;
+    //}

+    /// Copy constructor
+    SubVector(const SubVector &other) : VectorBase<Real>() {
+        // this copy constructor needed for Range() to work in base class.
+        VectorBase<Real>::data_ = other.data_;
+        VectorBase<Real>::dim_ = other.dim_;
+    }

+    /// Constructor from a pointer to memory and a length. Keeps a pointer
+    /// to the data but does not take ownership (will never delete).
+    /// Caution: this constructor enables you to evade const constraints.
+    SubVector(const Real *data, MatrixIndexT length) : VectorBase<Real>() {
+        VectorBase<Real>::data_ = const_cast<Real *>(data);
+        VectorBase<Real>::dim_ = length;
+    }
+
+    /// This operation does not preserve const-ness, so be careful.
+    SubVector(const MatrixBase<Real> &matrix, MatrixIndexT row) {
+        VectorBase<Real>::data_ = const_cast<Real *>(matrix.RowData(row));
+        VectorBase<Real>::dim_ = matrix.NumCols();
+    }
+
+    ~SubVector() {}  ///< Destructor (does nothing; no pointers are owned here).
+
+  private:
+    /// Disallow assignment operator.
+    SubVector &operator=(const SubVector &other) {}
+};
+
+/// @} end of "addtogroup matrix_group"
+/// \addtogroup matrix_funcs_io
+/// @{
+/// Output to a C++ stream.  Non-binary by default (use Write for
+/// binary output).
+template <typename Real>
+std::ostream &operator<<(std::ostream &out, const VectorBase<Real> &v);
+
+/// Input from a C++ stream.  Will automatically read text or
+/// binary data from the stream.
+template <typename Real>
+std::istream &operator>>(std::istream &in, VectorBase<Real> &v);
+
+/// Input from a C++ stream.  Will automatically read text or
+/// binary data from the stream.
+template <typename Real>
+std::istream &operator>>(std::istream &in, Vector<Real> &v);
+/// @} end of \addtogroup matrix_funcs_io
+
+/// \addtogroup matrix_funcs_scalar
+/// @{
+
+
+// template <typename Real>
+// bool ApproxEqual(const VectorBase<Real> &a,
+//                  const VectorBase<Real> &b, Real tol = 0.01) {
+//    return a.ApproxEqual(b, tol);
+// }
+
+// template <typename Real>
+// inline void AssertEqual(VectorBase<Real> &a, VectorBase<Real> &b,
+//                         float tol = 0.01) {
+//    KALDI_ASSERT(a.ApproxEqual(b, tol));
+// }
+
+
+}  // namespace kaldi
+
+// we need to include the implementation
+#include "matrix/kaldi-vector-inl.h"
+
+
+#endif  // KALDI_MATRIX_KALDI_VECTOR_H_
diff --git a/speechx/speechx/kaldi/matrix/matrix-common.h b/runtime/engine/common/matrix/matrix-common.h
similarity index 50%
rename from speechx/speechx/kaldi/matrix/matrix-common.h
rename to runtime/engine/common/matrix/matrix-common.h
index f7047d71..e915db0a 100644
--- a/speechx/speechx/kaldi/matrix/matrix-common.h
+++ b/runtime/engine/common/matrix/matrix-common.h
@@ -27,71 +27,58 @@ namespace kaldi {
 // these enums equal the CblasTrans and CblasNoTrans constants from the CBLAS library
-// we are writing them as literals because we don't want to include here matrix/kaldi-blas.h,
-// which puts many symbols into global scope (like "real") via the header f2c.h
+// we are writing them as literals because we don't want to include here
+// matrix/kaldi-blas.h,
+// which puts many symbols into global scope (like "real") via the header f2c.h
 typedef enum {
-  kTrans = 112,  // = CblasTrans
-  kNoTrans = 111  // = CblasNoTrans
+    kTrans = 112,   // = CblasTrans
+    kNoTrans = 111  // = CblasNoTrans
 } MatrixTransposeType;
 
-typedef enum {
-  kSetZero,
-  kUndefined,
-  kCopyData
-} MatrixResizeType;
+typedef enum { kSetZero, kUndefined, kCopyData } MatrixResizeType;
 
 typedef enum {
-  kDefaultStride,
-  kStrideEqualNumCols,
+    kDefaultStride,
+    kStrideEqualNumCols,
 } MatrixStrideType;
 
 typedef enum {
-  kTakeLower,
-  kTakeUpper,
-  kTakeMean,
-  kTakeMeanAndCheck
+    kTakeLower,
+    kTakeUpper,
+    kTakeMean,
+    kTakeMeanAndCheck
 } SpCopyType;
 
-template<typename Real> class VectorBase;
-template<typename Real> class Vector;
-template<typename Real> class SubVector;
-template<typename Real> class MatrixBase;
-template<typename Real> class SubMatrix;
-template<typename Real> class Matrix;
-template<typename Real> class SpMatrix;
-template<typename Real> class TpMatrix;
-template<typename Real> class PackedMatrix;
-template<typename Real> class SparseMatrix;
-
-// these are classes that won't be defined in this
-// directory; they're mostly needed for friend declarations.
-template<typename Real> class CuMatrixBase;
-template<typename Real> class CuSubMatrix;
-template<typename Real> class CuMatrix;
-template<typename Real> class CuVectorBase;
-template<typename Real> class CuSubVector;
-template<typename Real> class CuVector;
-template<typename Real> class CuPackedMatrix;
-template<typename Real> class CuSpMatrix;
-template<typename Real> class CuTpMatrix;
-template<typename Real> class CuSparseMatrix;
-
-class CompressedMatrix;
-class GeneralMatrix;
+template <typename Real>
+class VectorBase;
+template <typename Real>
+class Vector;
+template <typename Real>
+class SubVector;
+template <typename Real>
+class MatrixBase;
+template <typename Real>
+class SubMatrix;
+template <typename Real>
+class Matrix;
+
 /// This class provides a way for switching between double and float types.
-template<typename T> class OtherReal { };  // useful in reading+writing routines
-                                           // to switch double and float.
+template <typename T>
+class OtherReal {};  // useful in reading+writing routines
+                     // to switch double and float.
 /// A specialized class for switching from float to double.
-template<> class OtherReal<float> {
- public:
-  typedef double Real;
+template <>
+class OtherReal<float> {
+  public:
+    typedef double Real;
 };
 /// A specialized class for switching from double to float.
-template<> class OtherReal<double> {
- public:
-  typedef float Real;
+template <>
+class OtherReal<double> {
+  public:
+    typedef float Real;
 };
 
@@ -100,12 +87,10 @@ typedef int32 SignedMatrixIndexT;
 typedef uint32 UnsignedMatrixIndexT;
 
 // If you want to use size_t for the index type, do as follows instead:
-//typedef size_t MatrixIndexT;
-//typedef ssize_t SignedMatrixIndexT;
-//typedef size_t UnsignedMatrixIndexT;
-
-}
-
+// typedef size_t MatrixIndexT;
+// typedef ssize_t SignedMatrixIndexT;
+// typedef size_t UnsignedMatrixIndexT;
+}  // namespace kaldi
 
 #endif  // KALDI_MATRIX_MATRIX_COMMON_H_
diff --git a/runtime/engine/common/utils/CMakeLists.txt b/runtime/engine/common/utils/CMakeLists.txt
new file mode 100644
index 00000000..14733648
--- /dev/null
+++ b/runtime/engine/common/utils/CMakeLists.txt
@@ -0,0 +1,28 @@
+
+
+set(csrc
+  file_utils.cc
+  math.cc
+  strings.cc
+  audio_process.cc
+  timer.cc
+)
+
+add_library(utils ${csrc})
+
+if(WITH_TESTING)
+  enable_testing()
+
+  if(ANDROID)
+  else() # UNIX
+    link_libraries(gtest_main gmock)
+
+    add_executable(strings_test strings_test.cc)
+    target_link_libraries(strings_test PUBLIC utils)
+    add_test(
+      NAME strings_test
+      COMMAND strings_test
+    )
+  endif()
+endif()
+
diff --git a/runtime/engine/common/utils/audio_process.cc b/runtime/engine/common/utils/audio_process.cc
new file mode 100644
index 00000000..54540b85
--- /dev/null
+++ b/runtime/engine/common/utils/audio_process.cc
@@ -0,0 +1,83 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
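+
+// Waveform pre-processing helpers: int16 -> float scaling, linear or
+// gaussian (zero-mean, unit-variance) normalization, and power-to-dB
+// conversion.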
+
+#include "utils/audio_process.h"
+
+namespace ppspeech {
+
+int WaveformFloatNormal(std::vector<float>* waveform) {
+    int tot_samples = waveform->size();
+    for (int i = 0; i < tot_samples; i++) {
+        (*waveform)[i] = (*waveform)[i] / 32768.0;
+    }
+    return 0;
+}
+
+int WaveformNormal(std::vector<float>* waveform,
+                   bool wav_normal,
+                   const std::string& wav_normal_type,
+                   float wav_norm_mul_factor) {
+    if (wav_normal == false) {
+        return 0;
+    }
+    if (wav_normal_type == "linear") {
+        float amax = INT32_MIN;
+        for (int i = 0; i < waveform->size(); ++i) {
+            float tmp = std::abs((*waveform)[i]);
+            amax = std::max(amax, tmp);
+        }
+        float factor = 1.0 / (amax + 1e-8);
+        for (int i = 0; i < waveform->size(); ++i) {
+            (*waveform)[i] = (*waveform)[i] * factor * wav_norm_mul_factor;
+        }
+    } else if (wav_normal_type == "gaussian") {
+        double sum = std::accumulate(waveform->begin(), waveform->end(), 0.0);
+        double mean = sum / waveform->size();  // mean
+
+        double accum = 0.0;
+        std::for_each(waveform->begin(), waveform->end(), [&](const double d) {
+            accum += (d - mean) * (d - mean);
+        });
+
+        double stdev =
+            sqrt(accum / (waveform->size() - 1));  // standard deviation
+        stdev = std::max(stdev, 1e-8);
+
+        for (int i = 0; i < waveform->size(); ++i) {
+            (*waveform)[i] =
+                wav_norm_mul_factor * ((*waveform)[i] - mean) / stdev;
+        }
+    } else {
+        printf("unsupported wav_normal_type: %s\n", wav_normal_type.c_str());
+        return -1;
+    }
+    return 0;
+}
+
+float PowerTodb(float in, float ref_value, float amin, float top_db) {
+    if (amin <= 0) {
+        printf("amin must be strictly positive\n");
+        return -1;
+    }
+
+    if (ref_value <= 0) {
+        printf("ref_value must be strictly positive\n");
+        return -1;
+    }
+
+    float out = 10.0 * log10(std::max(amin, in));
+    out -= 10.0 * log10(std::max(ref_value, amin));
+    return out;
+}
+
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/runtime/engine/common/utils/audio_process.h b/runtime/engine/common/utils/audio_process.h
new file mode 100644
index 00000000..164d4c07
--- /dev/null
+++ b/runtime/engine/common/utils/audio_process.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
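+
+// Example usage (an illustrative sketch; LoadPcm16AsFloat is a hypothetical
+// helper, not part of this codebase):
+//
+//   std::vector<float> wav = LoadPcm16AsFloat("a.wav");
+//   ppspeech::WaveformFloatNormal(&wav);                     // scale by 1/32768
+//   ppspeech::WaveformNormal(&wav, true, "gaussian", 1.0f);  // zero mean, unit std
+//   float db = ppspeech::PowerTodb(0.5f);                    // ~ -3.01 dB vs ref_value=1.0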
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstdio>
+#include <numeric>
+#include <string>
+#include <vector>
+
+namespace ppspeech {
+int WaveformFloatNormal(std::vector<float>* waveform);
+int WaveformNormal(std::vector<float>* waveform,
+                   bool wav_normal,
+                   const std::string& wav_normal_type,
+                   float wav_norm_mul_factor);
+float PowerTodb(float in,
+                float ref_value = 1.0,
+                float amin = 1e-10,
+                float top_db = 80.0);
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/runtime/engine/common/utils/blank_process_test.cc b/runtime/engine/common/utils/blank_process_test.cc
new file mode 100644
index 00000000..75f762ae
--- /dev/null
+++ b/runtime/engine/common/utils/blank_process_test.cc
@@ -0,0 +1,26 @@
+#include "utils/blank_process.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+TEST(BlankProcess, BlankProcessTest) {
+    std::string test_str = "我 今天 去 了 超市 花了 120 元。";
+    std::string out_str = ppspeech::BlankProcess(test_str);
+    int ret = out_str.compare("我今天去了超市花了120元。");
+    EXPECT_EQ(ret, 0);
+
+    test_str = "how are you today";
+    out_str = ppspeech::BlankProcess(test_str);
+    ret = out_str.compare("how are you today");
+    EXPECT_EQ(ret, 0);
+
+    test_str = "我 的 paper 在 哪里?";
+    out_str = ppspeech::BlankProcess(test_str);
+    ret = out_str.compare("我的paper在哪里?");
+    EXPECT_EQ(ret, 0);
+
+    test_str = "我 今天 去 了 超市 花了 120 元。";
+    out_str = ppspeech::BlankProcess(test_str);
+    ret = out_str.compare("我今天去了超市花了120元。");
+    EXPECT_EQ(ret, 0);
+}
\ No newline at end of file
diff --git a/speechx/speechx/utils/file_utils.cc b/runtime/engine/common/utils/file_utils.cc
similarity index 61%
rename from speechx/speechx/utils/file_utils.cc
rename to runtime/engine/common/utils/file_utils.cc
index c42a642c..385f2b65 100644
--- a/speechx/speechx/utils/file_utils.cc
+++ b/runtime/engine/common/utils/file_utils.cc
@@ -14,6 +14,8 @@
 
 #include "utils/file_utils.h"
 
+#include <sys/stat.h>
+
 namespace ppspeech {
 
 bool ReadFileToVector(const std::string& filename,
@@ -40,4 +40,31 @@ std::string ReadFile2String(const std::string& path) {
     return std::string((std::istreambuf_iterator<char>(input_file)),
                        std::istreambuf_iterator<char>());
 }
+
+bool FileExists(const std::string& strFilename) {
+    // this function is from:
+    // https://github.com/kaldi-asr/kaldi/blob/master/src/fstext/deterministic-fst-test.cc
+    struct stat stFileInfo;
+    bool blnReturn;
+    int intStat;
+
+    // Attempt to get the file attributes
+    intStat = stat(strFilename.c_str(), &stFileInfo);
+    if (intStat == 0) {
+        // We were able to get the file attributes
+        // so the file obviously exists.
+        blnReturn = true;
+    } else {
+        // We were not able to get the file attributes.
+        // This may mean that we don't have permission to
+        // access the folder which contains this file. If you
+        // need to do that level of checking, lookup the
+        // return values of stat which will give you
+        // more details on why stat failed.
+ blnReturn = false; + } + + return blnReturn; +} + } // namespace ppspeech diff --git a/speechx/speechx/utils/file_utils.h b/runtime/engine/common/utils/file_utils.h similarity index 94% rename from speechx/speechx/utils/file_utils.h rename to runtime/engine/common/utils/file_utils.h index a471e024..420740db 100644 --- a/speechx/speechx/utils/file_utils.h +++ b/runtime/engine/common/utils/file_utils.h @@ -20,4 +20,7 @@ bool ReadFileToVector(const std::string& filename, std::vector* data); std::string ReadFile2String(const std::string& path); + +bool FileExists(const std::string& filename); + } // namespace ppspeech diff --git a/speechx/speechx/utils/math.cc b/runtime/engine/common/utils/math.cc similarity index 97% rename from speechx/speechx/utils/math.cc rename to runtime/engine/common/utils/math.cc index 71656cb3..1f0c9c93 100644 --- a/speechx/speechx/utils/math.cc +++ b/runtime/engine/common/utils/math.cc @@ -15,13 +15,14 @@ // limitations under the License. #include "utils/math.h" +#include "base/basic_types.h" #include #include #include +#include #include - -#include "base/common.h" +#include namespace ppspeech { diff --git a/speechx/speechx/utils/math.h b/runtime/engine/common/utils/math.h similarity index 100% rename from speechx/speechx/utils/math.h rename to runtime/engine/common/utils/math.h diff --git a/runtime/engine/common/utils/picojson.h b/runtime/engine/common/utils/picojson.h new file mode 100644 index 00000000..2ac265f5 --- /dev/null +++ b/runtime/engine/common/utils/picojson.h @@ -0,0 +1,1230 @@ +/* + * Copyright 2009-2010 Cybozu Labs, Inc. + * Copyright 2011-2014 Kazuho Oku + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ +#ifndef picojson_h +#define picojson_h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define PICOJSON_USE_INT64 1 + +// for isnan/isinf +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#ifdef _MSC_VER +#include +#elif defined(__INTEL_COMPILER) +#include +#else +#include +#endif +} +#endif + +#ifndef PICOJSON_USE_RVALUE_REFERENCE +#if (defined(__cpp_rvalue_references) && __cpp_rvalue_references >= 200610) || \ + (defined(_MSC_VER) && _MSC_VER >= 1600) +#define PICOJSON_USE_RVALUE_REFERENCE 1 +#else +#define PICOJSON_USE_RVALUE_REFERENCE 0 +#endif +#endif // PICOJSON_USE_RVALUE_REFERENCE + +#ifndef PICOJSON_NOEXCEPT +#if PICOJSON_USE_RVALUE_REFERENCE +#define PICOJSON_NOEXCEPT noexcept +#else +#define PICOJSON_NOEXCEPT throw() +#endif +#endif + +// experimental support for int64_t (see README.mkdn for detail) +#ifdef PICOJSON_USE_INT64 +#define __STDC_FORMAT_MACROS +#include +#if __cplusplus >= 201103L +#include +#else +extern "C" { +#include +} +#endif +#endif + +// to disable the use of localeconv(3), set PICOJSON_USE_LOCALE to 0 +#ifndef PICOJSON_USE_LOCALE +#define PICOJSON_USE_LOCALE 1 +#endif +#if PICOJSON_USE_LOCALE +extern "C" { +#include +} +#endif + +#ifndef PICOJSON_ASSERT +#define PICOJSON_ASSERT(e) \ + do { \ + if (!(e)) throw std::runtime_error(#e); \ + } while (0) +#endif + +#ifdef _MSC_VER +#define SNPRINTF _snprintf_s +#pragma warning(push) +#pragma warning(disable : 4244) // conversion from int to char +#pragma warning(disable : 4127) // conditional expression is constant +#pragma warning(disable : 4702) // unreachable code +#pragma warning(disable : 4706) // assignment within conditional expression +#else +#define SNPRINTF snprintf +#endif + +namespace picojson { + +enum { + null_type, + boolean_type, + number_type, + string_type, + array_type, + object_type +#ifdef PICOJSON_USE_INT64 + , + int64_type +#endif +}; + +enum { INDENT_WIDTH = 2, DEFAULT_MAX_DEPTHS = 100 }; + +struct null {}; + +class value { + public: + typedef std::vector array; + typedef std::map object; + union _storage { + bool boolean_; + double number_; +#ifdef PICOJSON_USE_INT64 + int64_t int64_; +#endif + std::string *string_; + array *array_; + object *object_; + }; + + protected: + int type_; + _storage u_; + + public: + value(); + value(int type, bool); + explicit value(bool b); +#ifdef PICOJSON_USE_INT64 + explicit value(int64_t i); +#endif + explicit value(double n); + explicit value(const std::string &s); + explicit value(const array &a); + explicit value(const object &o); +#if PICOJSON_USE_RVALUE_REFERENCE + explicit value(std::string &&s); + explicit value(array &&a); + explicit value(object &&o); +#endif + explicit value(const char *s); + value(const char *s, size_t len); + ~value(); + value(const value &x); + value &operator=(const value &x); +#if PICOJSON_USE_RVALUE_REFERENCE + value(value &&x) PICOJSON_NOEXCEPT; + value &operator=(value &&x) PICOJSON_NOEXCEPT; +#endif + void swap(value &x) PICOJSON_NOEXCEPT; + template + bool is() const; + template + const T &get() const; + template + T &get(); + template + void set(const T &); +#if PICOJSON_USE_RVALUE_REFERENCE + template + void set(T &&); +#endif + bool evaluate_as_boolean() const; + const value &get(const size_t idx) const; + const value &get(const std::string &key) const; + value &get(const size_t idx); + value &get(const std::string &key); + + bool contains(const size_t idx) const; + bool contains(const std::string &key) const; + 
std::string to_str() const; + template + void serialize(Iter os, bool prettify = false) const; + std::string serialize(bool prettify = false) const; + + private: + template + value(const T *); // intentionally defined to block implicit conversion of + // pointer to bool + template + static void _indent(Iter os, int indent); + template + void _serialize(Iter os, int indent) const; + std::string _serialize(int indent) const; + void clear(); +}; + +typedef value::array array; +typedef value::object object; + +inline value::value() : type_(null_type), u_() {} + +inline value::value(int type, bool) : type_(type), u_() { + switch (type) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(boolean_, false); + INIT(number_, 0.0); +#ifdef PICOJSON_USE_INT64 + INIT(int64_, 0); +#endif + INIT(string_, new std::string()); + INIT(array_, new array()); + INIT(object_, new object()); +#undef INIT + default: + break; + } +} + +inline value::value(bool b) : type_(boolean_type), u_() { u_.boolean_ = b; } + +#ifdef PICOJSON_USE_INT64 +inline value::value(int64_t i) : type_(int64_type), u_() { u_.int64_ = i; } +#endif + +inline value::value(double n) : type_(number_type), u_() { + if ( +#ifdef _MSC_VER + !_finite(n) +#elif __cplusplus >= 201103L + std::isnan(n) || std::isinf(n) +#else + isnan(n) || isinf(n) +#endif + ) { + throw std::overflow_error(""); + } + u_.number_ = n; +} + +inline value::value(const std::string &s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const array &a) : type_(array_type), u_() { + u_.array_ = new array(a); +} + +inline value::value(const object &o) : type_(object_type), u_() { + u_.object_ = new object(o); +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(std::string &&s) : type_(string_type), u_() { + u_.string_ = new std::string(std::move(s)); +} + +inline value::value(array &&a) : type_(array_type), u_() { + u_.array_ = new array(std::move(a)); +} + +inline value::value(object &&o) : type_(object_type), u_() { + u_.object_ = new object(std::move(o)); +} +#endif + +inline value::value(const char *s) : type_(string_type), u_() { + u_.string_ = new std::string(s); +} + +inline value::value(const char *s, size_t len) : type_(string_type), u_() { + u_.string_ = new std::string(s, len); +} + +inline void value::clear() { + switch (type_) { +#define DEINIT(p) \ + case p##type: \ + delete u_.p; \ + break + DEINIT(string_); + DEINIT(array_); + DEINIT(object_); +#undef DEINIT + default: + break; + } +} + +inline value::~value() { clear(); } + +inline value::value(const value &x) : type_(x.type_), u_() { + switch (type_) { +#define INIT(p, v) \ + case p##type: \ + u_.p = v; \ + break + INIT(string_, new std::string(*x.u_.string_)); + INIT(array_, new array(*x.u_.array_)); + INIT(object_, new object(*x.u_.object_)); +#undef INIT + default: + u_ = x.u_; + break; + } +} + +inline value &value::operator=(const value &x) { + if (this != &x) { + value t(x); + swap(t); + } + return *this; +} + +#if PICOJSON_USE_RVALUE_REFERENCE +inline value::value(value &&x) PICOJSON_NOEXCEPT : type_(null_type), u_() { + swap(x); +} +inline value &value::operator=(value &&x) PICOJSON_NOEXCEPT { + swap(x); + return *this; +} +#endif +inline void value::swap(value &x) PICOJSON_NOEXCEPT { + std::swap(type_, x.type_); + std::swap(u_, x.u_); +} + +#define IS(ctype, jtype) \ + template <> \ + inline bool value::is() const { \ + return type_ == jtype##_type; \ + } +IS(null, null) +IS(bool, boolean) +#ifdef PICOJSON_USE_INT64 +IS(int64_t, 
int64) +#endif +IS(std::string, string) +IS(array, array) +IS(object, object) +#undef IS +template <> +inline bool value::is() const { + return type_ == number_type +#ifdef PICOJSON_USE_INT64 + || type_ == int64_type +#endif + ; +} + +#define GET(ctype, var) \ + template <> \ + inline const ctype &value::get() const { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" && \ + is()); \ + return var; \ + } \ + template <> \ + inline ctype &value::get() { \ + PICOJSON_ASSERT("type mismatch! call is() before get()" && \ + is()); \ + return var; \ + } +GET(bool, u_.boolean_) +GET(std::string, *u_.string_) +GET(array, *u_.array_) +GET(object, *u_.object_) +#ifdef PICOJSON_USE_INT64 +GET(double, + (type_ == int64_type && + (const_cast(this)->type_ = number_type, + (const_cast(this)->u_.number_ = u_.int64_)), + u_.number_)) +GET(int64_t, u_.int64_) +#else +GET(double, u_.number_) +#endif +#undef GET + +#define SET(ctype, jtype, setter) \ + template <> \ + inline void value::set(const ctype &_val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +SET(bool, boolean, u_.boolean_ = _val;) +SET(std::string, string, u_.string_ = new std::string(_val);) +SET(array, array, u_.array_ = new array(_val);) +SET(object, object, u_.object_ = new object(_val);) +SET(double, number, u_.number_ = _val;) +#ifdef PICOJSON_USE_INT64 +SET(int64_t, int64, u_.int64_ = _val;) +#endif +#undef SET + +#if PICOJSON_USE_RVALUE_REFERENCE +#define MOVESET(ctype, jtype, setter) \ + template <> \ + inline void value::set(ctype && _val) { \ + clear(); \ + type_ = jtype##_type; \ + setter \ + } +MOVESET(std::string, string, u_.string_ = new std::string(std::move(_val));) +MOVESET(array, array, u_.array_ = new array(std::move(_val));) +MOVESET(object, object, u_.object_ = new object(std::move(_val));) +#undef MOVESET +#endif + +inline bool value::evaluate_as_boolean() const { + switch (type_) { + case null_type: + return false; + case boolean_type: + return u_.boolean_; + case number_type: + return u_.number_ != 0; +#ifdef PICOJSON_USE_INT64 + case int64_type: + return u_.int64_ != 0; +#endif + case string_type: + return !u_.string_->empty(); + default: + return true; + } +} + +inline const value &value::get(const size_t idx) const { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; +} + +inline value &value::get(const size_t idx) { + static value s_null; + PICOJSON_ASSERT(is()); + return idx < u_.array_->size() ? (*u_.array_)[idx] : s_null; +} + +inline const value &value::get(const std::string &key) const { + static value s_null; + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline value &value::get(const std::string &key) { + static value s_null; + PICOJSON_ASSERT(is()); + object::iterator i = u_.object_->find(key); + return i != u_.object_->end() ? i->second : s_null; +} + +inline bool value::contains(const size_t idx) const { + PICOJSON_ASSERT(is()); + return idx < u_.array_->size(); +} + +inline bool value::contains(const std::string &key) const { + PICOJSON_ASSERT(is()); + object::const_iterator i = u_.object_->find(key); + return i != u_.object_->end(); +} + +inline std::string value::to_str() const { + switch (type_) { + case null_type: + return "null"; + case boolean_type: + return u_.boolean_ ? 
"true" : "false"; +#ifdef PICOJSON_USE_INT64 + case int64_type: { + char buf[sizeof("-9223372036854775808")]; + SNPRINTF(buf, sizeof(buf), "%" PRId64, u_.int64_); + return buf; + } +#endif + case number_type: { + char buf[256]; + double tmp; + SNPRINTF( + buf, + sizeof(buf), + fabs(u_.number_) < (1ULL << 53) && modf(u_.number_, &tmp) == 0 + ? "%.f" + : "%.17g", + u_.number_); +#if PICOJSON_USE_LOCALE + char *decimal_point = localeconv()->decimal_point; + if (strcmp(decimal_point, ".") != 0) { + size_t decimal_point_len = strlen(decimal_point); + for (char *p = buf; *p != '\0'; ++p) { + if (strncmp(p, decimal_point, decimal_point_len) == 0) { + return std::string(buf, p) + "." + + (p + decimal_point_len); + } + } + } +#endif + return buf; + } + case string_type: + return *u_.string_; + case array_type: + return "array"; + case object_type: + return "object"; + default: + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + } + return std::string(); +} + +template +void copy(const std::string &s, Iter oi) { + std::copy(s.begin(), s.end(), oi); +} + +template +struct serialize_str_char { + Iter oi; + void operator()(char c) { + switch (c) { +#define MAP(val, sym) \ + case val: \ + copy(sym, oi); \ + break + MAP('"', "\\\""); + MAP('\\', "\\\\"); + MAP('/', "\\/"); + MAP('\b', "\\b"); + MAP('\f', "\\f"); + MAP('\n', "\\n"); + MAP('\r', "\\r"); + MAP('\t', "\\t"); +#undef MAP + default: + if (static_cast(c) < 0x20 || c == 0x7f) { + char buf[7]; + SNPRINTF(buf, sizeof(buf), "\\u%04x", c & 0xff); + copy(buf, buf + 6, oi); + } else { + *oi++ = c; + } + break; + } + } +}; + +template +void serialize_str(const std::string &s, Iter oi) { + *oi++ = '"'; + serialize_str_char process_char = {oi}; + std::for_each(s.begin(), s.end(), process_char); + *oi++ = '"'; +} + +template +void value::serialize(Iter oi, bool prettify) const { + return _serialize(oi, prettify ? 0 : -1); +} + +inline std::string value::serialize(bool prettify) const { + return _serialize(prettify ? 
0 : -1); +} + +template +void value::_indent(Iter oi, int indent) { + *oi++ = '\n'; + for (int i = 0; i < indent * INDENT_WIDTH; ++i) { + *oi++ = ' '; + } +} + +template +void value::_serialize(Iter oi, int indent) const { + switch (type_) { + case string_type: + serialize_str(*u_.string_, oi); + break; + case array_type: { + *oi++ = '['; + if (indent != -1) { + ++indent; + } + for (array::const_iterator i = u_.array_->begin(); + i != u_.array_->end(); + ++i) { + if (i != u_.array_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + i->_serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.array_->empty()) { + _indent(oi, indent); + } + } + *oi++ = ']'; + break; + } + case object_type: { + *oi++ = '{'; + if (indent != -1) { + ++indent; + } + for (object::const_iterator i = u_.object_->begin(); + i != u_.object_->end(); + ++i) { + if (i != u_.object_->begin()) { + *oi++ = ','; + } + if (indent != -1) { + _indent(oi, indent); + } + serialize_str(i->first, oi); + *oi++ = ':'; + if (indent != -1) { + *oi++ = ' '; + } + i->second._serialize(oi, indent); + } + if (indent != -1) { + --indent; + if (!u_.object_->empty()) { + _indent(oi, indent); + } + } + *oi++ = '}'; + break; + } + default: + copy(to_str(), oi); + break; + } + if (indent == 0) { + *oi++ = '\n'; + } +} + +inline std::string value::_serialize(int indent) const { + std::string s; + _serialize(std::back_inserter(s), indent); + return s; +} + +template +class input { + protected: + Iter cur_, end_; + bool consumed_; + int line_; + + public: + input(const Iter &first, const Iter &last) + : cur_(first), end_(last), consumed_(false), line_(1) {} + int getc() { + if (consumed_) { + if (*cur_ == '\n') { + ++line_; + } + ++cur_; + } + if (cur_ == end_) { + consumed_ = false; + return -1; + } + consumed_ = true; + return *cur_ & 0xff; + } + void ungetc() { consumed_ = false; } + Iter cur() const { + if (consumed_) { + input *self = const_cast *>(this); + self->consumed_ = false; + ++self->cur_; + } + return cur_; + } + int line() const { return line_; } + void skip_ws() { + while (1) { + int ch = getc(); + if (!(ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r')) { + ungetc(); + break; + } + } + } + bool expect(const int expected) { + skip_ws(); + if (getc() != expected) { + ungetc(); + return false; + } + return true; + } + bool match(const std::string &pattern) { + for (std::string::const_iterator pi(pattern.begin()); + pi != pattern.end(); + ++pi) { + if (getc() != *pi) { + ungetc(); + return false; + } + } + return true; + } +}; + +template +inline int _parse_quadhex(input &in) { + int uni_ch = 0, hex; + for (int i = 0; i < 4; i++) { + if ((hex = in.getc()) == -1) { + return -1; + } + if ('0' <= hex && hex <= '9') { + hex -= '0'; + } else if ('A' <= hex && hex <= 'F') { + hex -= 'A' - 0xa; + } else if ('a' <= hex && hex <= 'f') { + hex -= 'a' - 0xa; + } else { + in.ungetc(); + return -1; + } + uni_ch = uni_ch * 16 + hex; + } + return uni_ch; +} + +template +inline bool _parse_codepoint(String &out, input &in) { + int uni_ch; + if ((uni_ch = _parse_quadhex(in)) == -1) { + return false; + } + if (0xd800 <= uni_ch && uni_ch <= 0xdfff) { + if (0xdc00 <= uni_ch) { + // a second 16-bit of a surrogate pair appeared + return false; + } + // first 16-bit of surrogate pair, get the next one + if (in.getc() != '\\' || in.getc() != 'u') { + in.ungetc(); + return false; + } + int second = _parse_quadhex(in); + if (!(0xdc00 <= second && second <= 0xdfff)) { + return false; + } + uni_ch = ((uni_ch - 
0xd800) << 10) | ((second - 0xdc00) & 0x3ff); + uni_ch += 0x10000; + } + if (uni_ch < 0x80) { + out.push_back(static_cast(uni_ch)); + } else { + if (uni_ch < 0x800) { + out.push_back(static_cast(0xc0 | (uni_ch >> 6))); + } else { + if (uni_ch < 0x10000) { + out.push_back(static_cast(0xe0 | (uni_ch >> 12))); + } else { + out.push_back(static_cast(0xf0 | (uni_ch >> 18))); + out.push_back( + static_cast(0x80 | ((uni_ch >> 12) & 0x3f))); + } + out.push_back(static_cast(0x80 | ((uni_ch >> 6) & 0x3f))); + } + out.push_back(static_cast(0x80 | (uni_ch & 0x3f))); + } + return true; +} + +template +inline bool _parse_string(String &out, input &in) { + while (1) { + int ch = in.getc(); + if (ch < ' ') { + in.ungetc(); + return false; + } else if (ch == '"') { + return true; + } else if (ch == '\\') { + if ((ch = in.getc()) == -1) { + return false; + } + switch (ch) { +#define MAP(sym, val) \ + case sym: \ + out.push_back(val); \ + break + MAP('"', '\"'); + MAP('\\', '\\'); + MAP('/', '/'); + MAP('b', '\b'); + MAP('f', '\f'); + MAP('n', '\n'); + MAP('r', '\r'); + MAP('t', '\t'); +#undef MAP + case 'u': + if (!_parse_codepoint(out, in)) { + return false; + } + break; + default: + return false; + } + } else { + out.push_back(static_cast(ch)); + } + } + return false; +} + +template +inline bool _parse_array(Context &ctx, input &in) { + if (!ctx.parse_array_start()) { + return false; + } + size_t idx = 0; + if (in.expect(']')) { + return ctx.parse_array_stop(idx); + } + do { + if (!ctx.parse_array_item(in, idx)) { + return false; + } + idx++; + } while (in.expect(',')); + return in.expect(']') && ctx.parse_array_stop(idx); +} + +template +inline bool _parse_object(Context &ctx, input &in) { + if (!ctx.parse_object_start()) { + return false; + } + if (in.expect('}')) { + return ctx.parse_object_stop(); + } + do { + std::string key; + if (!in.expect('"') || !_parse_string(key, in) || !in.expect(':')) { + return false; + } + if (!ctx.parse_object_item(in, key)) { + return false; + } + } while (in.expect(',')); + return in.expect('}') && ctx.parse_object_stop(); +} + +template +inline std::string _parse_number(input &in) { + std::string num_str; + while (1) { + int ch = in.getc(); + if (('0' <= ch && ch <= '9') || ch == '+' || ch == '-' || ch == 'e' || + ch == 'E') { + num_str.push_back(static_cast(ch)); + } else if (ch == '.') { +#if PICOJSON_USE_LOCALE + num_str += localeconv()->decimal_point; +#else + num_str.push_back('.'); +#endif + } else { + in.ungetc(); + break; + } + } + return num_str; +} + +template +inline bool _parse(Context &ctx, input &in) { + in.skip_ws(); + int ch = in.getc(); + switch (ch) { +#define IS(ch, text, op) \ + case ch: \ + if (in.match(text) && op) { \ + return true; \ + } else { \ + return false; \ + } + IS('n', "ull", ctx.set_null()); + IS('f', "alse", ctx.set_bool(false)); + IS('t', "rue", ctx.set_bool(true)); +#undef IS + case '"': + return ctx.parse_string(in); + case '[': + return _parse_array(ctx, in); + case '{': + return _parse_object(ctx, in); + default: + if (('0' <= ch && ch <= '9') || ch == '-') { + double f; + char *endp; + in.ungetc(); + std::string num_str(_parse_number(in)); + if (num_str.empty()) { + return false; + } +#ifdef PICOJSON_USE_INT64 + { + errno = 0; + intmax_t ival = strtoimax(num_str.c_str(), &endp, 10); + if (errno == 0 && + std::numeric_limits::min() <= ival && + ival <= std::numeric_limits::max() && + endp == num_str.c_str() + num_str.size()) { + ctx.set_int64(ival); + return true; + } + } +#endif + f = strtod(num_str.c_str(), &endp); + if (endp 
== num_str.c_str() + num_str.size()) { + ctx.set_number(f); + return true; + } + return false; + } + break; + } + in.ungetc(); + return false; +} + +class deny_parse_context { + public: + bool set_null() { return false; } + bool set_bool(bool) { return false; } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { return false; } +#endif + bool set_number(double) { return false; } + template + bool parse_string(input &) { + return false; + } + bool parse_array_start() { return false; } + template + bool parse_array_item(input &, size_t) { + return false; + } + bool parse_array_stop(size_t) { return false; } + bool parse_object_start() { return false; } + template + bool parse_object_item(input &, const std::string &) { + return false; + } +}; + +class default_parse_context { + protected: + value *out_; + size_t depths_; + + public: + default_parse_context(value *out, size_t depths = DEFAULT_MAX_DEPTHS) + : out_(out), depths_(depths) {} + bool set_null() { + *out_ = value(); + return true; + } + bool set_bool(bool b) { + *out_ = value(b); + return true; + } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t i) { + *out_ = value(i); + return true; + } +#endif + bool set_number(double f) { + *out_ = value(f); + return true; + } + template + bool parse_string(input &in) { + *out_ = value(string_type, false); + return _parse_string(out_->get(), in); + } + bool parse_array_start() { + if (depths_ == 0) return false; + --depths_; + *out_ = value(array_type, false); + return true; + } + template + bool parse_array_item(input &in, size_t) { + array &a = out_->get(); + a.push_back(value()); + default_parse_context ctx(&a.back(), depths_); + return _parse(ctx, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) return false; + *out_ = value(object_type, false); + return true; + } + template + bool parse_object_item(input &in, const std::string &key) { + object &o = out_->get(); + default_parse_context ctx(&o[key], depths_); + return _parse(ctx, in); + } + bool parse_object_stop() { + ++depths_; + return true; + } + + private: + default_parse_context(const default_parse_context &); + default_parse_context &operator=(const default_parse_context &); +}; + +class null_parse_context { + protected: + size_t depths_; + + public: + struct dummy_str { + void push_back(int) {} + }; + + public: + null_parse_context(size_t depths = DEFAULT_MAX_DEPTHS) : depths_(depths) {} + bool set_null() { return true; } + bool set_bool(bool) { return true; } +#ifdef PICOJSON_USE_INT64 + bool set_int64(int64_t) { return true; } +#endif + bool set_number(double) { return true; } + template + bool parse_string(input &in) { + dummy_str s; + return _parse_string(s, in); + } + bool parse_array_start() { + if (depths_ == 0) return false; + --depths_; + return true; + } + template + bool parse_array_item(input &in, size_t) { + return _parse(*this, in); + } + bool parse_array_stop(size_t) { + ++depths_; + return true; + } + bool parse_object_start() { + if (depths_ == 0) return false; + --depths_; + return true; + } + template + bool parse_object_item(input &in, const std::string &) { + ++depths_; + return _parse(*this, in); + } + bool parse_object_stop() { return true; } + + private: + null_parse_context(const null_parse_context &); + null_parse_context &operator=(const null_parse_context &); +}; + +// obsolete, use the version below +template +inline std::string parse(value &out, Iter &pos, const Iter &last) { + std::string err; + pos = parse(out, pos, last, 
&err); + return err; +} + +template +inline Iter _parse(Context &ctx, + const Iter &first, + const Iter &last, + std::string *err) { + input in(first, last); + if (!_parse(ctx, in) && err != NULL) { + char buf[64]; + SNPRINTF(buf, sizeof(buf), "syntax error at line %d near: ", in.line()); + *err = buf; + while (1) { + int ch = in.getc(); + if (ch == -1 || ch == '\n') { + break; + } else if (ch >= ' ') { + err->push_back(static_cast(ch)); + } + } + } + return in.cur(); +} + +template +inline Iter parse(value &out, + const Iter &first, + const Iter &last, + std::string *err) { + default_parse_context ctx(&out); + return _parse(ctx, first, last, err); +} + +inline std::string parse(value &out, const std::string &s) { + std::string err; + parse(out, s.begin(), s.end(), &err); + return err; +} + +inline std::string parse(value &out, std::istream &is) { + std::string err; + parse(out, + std::istreambuf_iterator(is.rdbuf()), + std::istreambuf_iterator(), + &err); + return err; +} + +template +struct last_error_t { + static std::string s; +}; +template +std::string last_error_t::s; + +inline void set_last_error(const std::string &s) { last_error_t::s = s; } + +inline const std::string &get_last_error() { return last_error_t::s; } + +inline bool operator==(const value &x, const value &y) { + if (x.is()) return y.is(); +#define PICOJSON_CMP(type) \ + if (x.is()) return y.is() && x.get() == y.get() + PICOJSON_CMP(bool); + PICOJSON_CMP(double); + PICOJSON_CMP(std::string); + PICOJSON_CMP(array); + PICOJSON_CMP(object); +#undef PICOJSON_CMP + PICOJSON_ASSERT(0); +#ifdef _MSC_VER + __assume(0); +#endif + return false; +} + +inline bool operator!=(const value &x, const value &y) { return !(x == y); } +} + +#if !PICOJSON_USE_RVALUE_REFERENCE +namespace std { +template <> +inline void swap(picojson::value &x, picojson::value &y) { + x.swap(y); +} +} +#endif + +inline std::istream &operator>>(std::istream &is, picojson::value &x) { + picojson::set_last_error(std::string()); + const std::string err(picojson::parse(x, is)); + if (!err.empty()) { + picojson::set_last_error(err); + is.setstate(std::ios::failbit); + } + return is; +} + +inline std::ostream &operator<<(std::ostream &os, const picojson::value &x) { + x.serialize(std::ostream_iterator(os)); + return os; +} +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +#endif \ No newline at end of file diff --git a/runtime/engine/common/utils/strings.cc b/runtime/engine/common/utils/strings.cc new file mode 100644 index 00000000..91954d64 --- /dev/null +++ b/runtime/engine/common/utils/strings.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
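+
+// String helpers for ASR text post-processing: splitting and joining,
+// removing or adding the spaces around English words in mixed
+// Chinese-English output, and re-ordering tagged fractions
+// (num1/num2 becomes num2/num1).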
+
+#include <sstream>
+
+#include "utils/strings.h"
+
+namespace ppspeech {
+
+std::vector<std::string> StrSplit(const std::string& str,
+                                  const char* delim,
+                                  bool omit_empty_string) {
+    std::vector<std::string> outs;
+    int start = 0;
+    int end = str.size();
+    int found = 0;
+    while (found != std::string::npos) {
+        found = str.find_first_of(delim, start);
+        // start != end condition is for when the delimiter is at the end
+        if (!omit_empty_string || (found != start && start != end)) {
+            outs.push_back(str.substr(start, found - start));
+        }
+        start = found + 1;
+    }
+
+    return outs;
+}
+
+
+std::string StrJoin(const std::vector<std::string>& strs, const char* delim) {
+    std::stringstream ss;
+    for (ssize_t i = 0; i < strs.size(); ++i) {
+        ss << strs[i];
+        if (i < strs.size() - 1) {
+            ss << std::string(delim);
+        }
+    }
+    return ss.str();
+}
+
+std::string DelBlank(const std::string& str) {
+    std::string out = "";
+    int ptr_in = 0;    // the pointer of the input string (for traversal)
+    int end = str.size();
+    int ptr_out = -1;  // the pointer of the output string (last char)
+    while (ptr_in != end) {
+        while (ptr_in != end && str[ptr_in] == ' ') {
+            ptr_in += 1;
+        }
+        if (ptr_in == end)
+            return out;
+        if (ptr_out != -1 && isalpha(str[ptr_in]) && isalpha(str[ptr_out]) &&
+            str[ptr_in - 1] == ' ')
+            // add a space when the last and current chars are English
+            // letters and there are space(s) between them
+            out += ' ';
+        out += str[ptr_in];
+        ptr_out = ptr_in;
+        ptr_in += 1;
+    }
+    return out;
+}
+
+std::string AddBlank(const std::string& str) {
+    std::string out = "";
+    int ptr = 0;  // the pointer of the input string
+    int end = str.size();
+    while (ptr != end) {
+        if (isalpha(str[ptr])) {
+            if (ptr == 0 || str[ptr - 1] != ' ')
+                out += " ";  // add pre-space for an English word
+            while (isalpha(str[ptr])) {
+                out += str[ptr];
+                ptr += 1;
+            }
+            out += " ";  // add post-space for an English word
+        } else {
+            out += str[ptr];
+            ptr += 1;
+        }
+    }
+    return out;
+}
+
+std::string ReverseFraction(const std::string& str) {
+    std::string out = "";
+    int ptr = 0;  // the pointer of the input string
+    int end = str.size();
+    int left, right, frac;  // the start indices of the left tag, right tag
+                            // and '/'
+    left = right = frac = 0;
+    int len_tag = 5;  // length of ""
+
+    while (ptr != end) {
+        // find the position of the left tag, right tag and '/'
+        // (xxxnum1/num2)
+        left = str.find("", ptr);
+        if (left == -1)
+            break;
+        out += str.substr(ptr, left - ptr);  // content before left tag (xxx)
+        frac = str.find("/", left);
+        right = str.find("", frac);
+
+        out += str.substr(frac + 1, right - frac - 1) + '/' +
+               str.substr(left + len_tag, frac - left - len_tag);  // num2/num1
+        ptr = right + len_tag;
+    }
+    if (ptr != end) {
+        out += str.substr(ptr, end - ptr);
+    }
+    return out;
+}
+
+#ifdef _MSC_VER
+std::wstring ToWString(const std::string& str) {
+    unsigned len = str.size() * 2;
+    setlocale(LC_CTYPE, "");
+    wchar_t* p = new wchar_t[len];
+    mbstowcs(p, str.c_str(), len);
+    std::wstring wstr(p);
+    delete[] p;
+    return wstr;
+}
+#endif
+
+}  // namespace ppspeech
diff --git a/runtime/engine/common/utils/strings.h b/runtime/engine/common/utils/strings.h
new file mode 100644
index 00000000..cd79ae4f
--- /dev/null
+++ b/runtime/engine/common/utils/strings.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+namespace ppspeech {
+
+std::vector<std::string> StrSplit(const std::string& str,
+                                  const char* delim,
+                                  bool omit_empty_string = true);
+
+std::string StrJoin(const std::vector<std::string>& strs, const char* delim);
+
+std::string DelBlank(const std::string& str);
+
+std::string AddBlank(const std::string& str);
+
+std::string ReverseFraction(const std::string& str);
+
+#ifdef _MSC_VER
+std::wstring ToWString(const std::string& str);
+#endif
+
+}  // namespace ppspeech
diff --git a/runtime/engine/common/utils/strings_test.cc b/runtime/engine/common/utils/strings_test.cc
new file mode 100644
index 00000000..058b6a01
--- /dev/null
+++ b/runtime/engine/common/utils/strings_test.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "utils/strings.h"
+
+#include <iostream>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+
+TEST(StringTest, StrSplitTest) {
+    using ::testing::ElementsAre;
+
+    std::string test_str = "hello world";
+    std::vector<std::string> outs = ppspeech::StrSplit(test_str, " \t");
+    EXPECT_THAT(outs, ElementsAre("hello", "world"));
+}
+
+
+TEST(StringTest, StrJoinTest) {
+    std::vector<std::string> ins{"hello", "world"};
+    std::string out = ppspeech::StrJoin(ins, " ");
+    EXPECT_THAT(out, "hello world");
+}
+
+TEST(StringTest, DelBlankTest) {
+    std::string test_str = "我 今天 去 了 超市 花了 120 元。";
+    std::string out_str = ppspeech::DelBlank(test_str);
+    int ret = out_str.compare("我今天去了超市花了120元。");
+    EXPECT_EQ(ret, 0);
+
+    test_str = "how are you today";
+    out_str = ppspeech::DelBlank(test_str);
+    ret = out_str.compare("how are you today");
+    EXPECT_EQ(ret, 0);
+
+    test_str = "我 的 paper 在 哪里?";
+    out_str = ppspeech::DelBlank(test_str);
+    ret = out_str.compare("我的paper在哪里?");
+    EXPECT_EQ(ret, 0);
+}
+
+TEST(StringTest, AddBlankTest) {
+    std::string test_str = "how are you";
+    std::string out_str = ppspeech::AddBlank(test_str);
+    int ret = out_str.compare(" how are you ");
+    EXPECT_EQ(ret, 0);
+
+    test_str = "欢迎来到China。";
+    out_str = ppspeech::AddBlank(test_str);
+    ret = out_str.compare("欢迎来到 China 。");
+    EXPECT_EQ(ret, 0);
+}
+
+TEST(StringTest, ReverseFractionTest) {
+    std::string test_str = "3/1";
+    std::string out_str = ppspeech::ReverseFraction(test_str);
+    int ret = out_str.compare("1/3");
+    std::cout << out_str << std::endl;
+    EXPECT_EQ(ret, 0);
+}
diff --git a/runtime/engine/common/utils/timer.cc b/runtime/engine/common/utils/timer.cc
new file mode 100644
--- /dev/null
+++ b/runtime/engine/common/utils/timer.cc
+
+#include <chrono>
+
+#include "common/utils/timer.h"
+
+namespace ppspeech {
+
+struct TimerImpl {
+    TimerImpl() = default;
+    virtual ~TimerImpl() = default;
+    virtual void Reset() = 0;
+    // time in seconds
+    virtual double Elapsed() = 0;
+};
+
+class CpuTimerImpl : public TimerImpl {
+  public:
+    CpuTimerImpl() { Reset(); }
+
+    using
high_resolution_clock = std::chrono::high_resolution_clock; + + void Reset() override { begin_ = high_resolution_clock::now(); } + + // time in seconds + double Elapsed() override { + auto end = high_resolution_clock::now(); + auto dur = + std::chrono::duration_cast(end - begin_); + return dur.count() / 1000000.0; + } + + private: + high_resolution_clock::time_point begin_; +}; + +Timer::Timer() { + impl_ = std::make_unique(); +} + +Timer::~Timer() = default; + +void Timer::Reset() const { impl_->Reset(); } + +double Timer::Elapsed() const { return impl_->Elapsed(); } + + +} //namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/common/utils/timer.h b/runtime/engine/common/utils/timer.h new file mode 100644 index 00000000..6f4ae1f8 --- /dev/null +++ b/runtime/engine/common/utils/timer.h @@ -0,0 +1,39 @@ +// Copyright 2020 Xiaomi Corporation (authors: Haowen Qiu) +// Mobvoi Inc. (authors: Fangjun Kuang) +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace ppspeech { + +struct TimerImpl; + +class Timer { + public: + Timer(); + ~Timer(); + + void Reset() const; + + // time in seconds + double Elapsed() const; + + private: + std::unique_ptr impl_; +}; + +} //namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/kaldi/CMakeLists.txt b/runtime/engine/kaldi/CMakeLists.txt new file mode 100644 index 00000000..e55cecbb --- /dev/null +++ b/runtime/engine/kaldi/CMakeLists.txt @@ -0,0 +1,15 @@ +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +) + +add_subdirectory(base) +add_subdirectory(util) +if(WITH_ASR) + add_subdirectory(lat) + add_subdirectory(fstext) + add_subdirectory(decoder) + add_subdirectory(lm) + + add_subdirectory(fstbin) + add_subdirectory(lmbin) +endif() diff --git a/speechx/speechx/kaldi/base/CMakeLists.txt b/runtime/engine/kaldi/base/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/base/CMakeLists.txt rename to runtime/engine/kaldi/base/CMakeLists.txt diff --git a/speechx/speechx/kaldi/base/io-funcs-inl.h b/runtime/engine/kaldi/base/io-funcs-inl.h similarity index 100% rename from speechx/speechx/kaldi/base/io-funcs-inl.h rename to runtime/engine/kaldi/base/io-funcs-inl.h diff --git a/speechx/speechx/kaldi/base/io-funcs.cc b/runtime/engine/kaldi/base/io-funcs.cc similarity index 100% rename from speechx/speechx/kaldi/base/io-funcs.cc rename to runtime/engine/kaldi/base/io-funcs.cc diff --git a/speechx/speechx/kaldi/base/io-funcs.h b/runtime/engine/kaldi/base/io-funcs.h similarity index 100% rename from speechx/speechx/kaldi/base/io-funcs.h rename to runtime/engine/kaldi/base/io-funcs.h diff --git a/speechx/speechx/kaldi/base/kaldi-common.h b/runtime/engine/kaldi/base/kaldi-common.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-common.h rename to runtime/engine/kaldi/base/kaldi-common.h diff --git a/speechx/speechx/kaldi/base/kaldi-error.cc b/runtime/engine/kaldi/base/kaldi-error.cc 
similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-error.cc rename to runtime/engine/kaldi/base/kaldi-error.cc diff --git a/speechx/speechx/kaldi/base/kaldi-error.h b/runtime/engine/kaldi/base/kaldi-error.h similarity index 99% rename from speechx/speechx/kaldi/base/kaldi-error.h rename to runtime/engine/kaldi/base/kaldi-error.h index a9904a75..98bef74f 100644 --- a/speechx/speechx/kaldi/base/kaldi-error.h +++ b/runtime/engine/kaldi/base/kaldi-error.h @@ -181,7 +181,7 @@ private: // Also see KALDI_COMPILE_TIME_ASSERT, defined in base/kaldi-utils.h, and // KALDI_ASSERT_IS_INTEGER_TYPE and KALDI_ASSERT_IS_FLOATING_TYPE, also defined // there. -#ifndef NDEBUG +#ifdef PPS_DEBUG #define KALDI_ASSERT(cond) \ do { \ if (cond) \ diff --git a/speechx/speechx/kaldi/base/kaldi-math.cc b/runtime/engine/kaldi/base/kaldi-math.cc similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-math.cc rename to runtime/engine/kaldi/base/kaldi-math.cc diff --git a/speechx/speechx/kaldi/base/kaldi-math.h b/runtime/engine/kaldi/base/kaldi-math.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-math.h rename to runtime/engine/kaldi/base/kaldi-math.h diff --git a/speechx/speechx/kaldi/base/kaldi-types.h b/runtime/engine/kaldi/base/kaldi-types.h similarity index 88% rename from speechx/speechx/kaldi/base/kaldi-types.h rename to runtime/engine/kaldi/base/kaldi-types.h index c6a3e1ae..bf8a2722 100644 --- a/speechx/speechx/kaldi/base/kaldi-types.h +++ b/runtime/engine/kaldi/base/kaldi-types.h @@ -40,11 +40,23 @@ typedef float BaseFloat; #include // for discussion on what to do if you need compile kaldi -// without OpenFST, see the bottom of this this file +// without OpenFST, see the bottom of this file #ifndef COMPILE_WITHOUT_OPENFST +#ifdef WITH_ASR #include +#else +using int8 = int8_t; +using int16 = int16_t; +using int32 = int32_t; +using int64 = int64_t; + +using uint8 = uint8_t; +using uint16 = uint16_t; +using uint32 = uint32_t; +using uint64 = uint64_t; +#endif namespace kaldi { using ::int16; diff --git a/speechx/speechx/kaldi/base/kaldi-utils.cc b/runtime/engine/kaldi/base/kaldi-utils.cc similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-utils.cc rename to runtime/engine/kaldi/base/kaldi-utils.cc diff --git a/speechx/speechx/kaldi/base/kaldi-utils.h b/runtime/engine/kaldi/base/kaldi-utils.h similarity index 100% rename from speechx/speechx/kaldi/base/kaldi-utils.h rename to runtime/engine/kaldi/base/kaldi-utils.h diff --git a/speechx/speechx/kaldi/base/timer.cc b/runtime/engine/kaldi/base/timer.cc similarity index 100% rename from speechx/speechx/kaldi/base/timer.cc rename to runtime/engine/kaldi/base/timer.cc diff --git a/speechx/speechx/kaldi/base/timer.h b/runtime/engine/kaldi/base/timer.h similarity index 100% rename from speechx/speechx/kaldi/base/timer.h rename to runtime/engine/kaldi/base/timer.h diff --git a/speechx/speechx/kaldi/base/version.h b/runtime/engine/kaldi/base/version.h similarity index 100% rename from speechx/speechx/kaldi/base/version.h rename to runtime/engine/kaldi/base/version.h diff --git a/speechx/speechx/kaldi/decoder/CMakeLists.txt b/runtime/engine/kaldi/decoder/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/decoder/CMakeLists.txt rename to runtime/engine/kaldi/decoder/CMakeLists.txt diff --git a/speechx/speechx/kaldi/decoder/decodable-itf.h b/runtime/engine/kaldi/decoder/decodable-itf.h similarity index 100% rename from speechx/speechx/kaldi/decoder/decodable-itf.h rename to 
runtime/engine/kaldi/decoder/decodable-itf.h diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc b/runtime/engine/kaldi/decoder/lattice-faster-decoder.cc similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-decoder.cc rename to runtime/engine/kaldi/decoder/lattice-faster-decoder.cc diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-decoder.h b/runtime/engine/kaldi/decoder/lattice-faster-decoder.h similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-decoder.h rename to runtime/engine/kaldi/decoder/lattice-faster-decoder.h diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc b/runtime/engine/kaldi/decoder/lattice-faster-online-decoder.cc similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.cc rename to runtime/engine/kaldi/decoder/lattice-faster-online-decoder.cc diff --git a/speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h b/runtime/engine/kaldi/decoder/lattice-faster-online-decoder.h similarity index 100% rename from speechx/speechx/kaldi/decoder/lattice-faster-online-decoder.h rename to runtime/engine/kaldi/decoder/lattice-faster-online-decoder.h diff --git a/speechx/speechx/kaldi/fstbin/CMakeLists.txt b/runtime/engine/kaldi/fstbin/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/fstbin/CMakeLists.txt rename to runtime/engine/kaldi/fstbin/CMakeLists.txt diff --git a/speechx/speechx/kaldi/fstbin/fstaddselfloops.cc b/runtime/engine/kaldi/fstbin/fstaddselfloops.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstaddselfloops.cc rename to runtime/engine/kaldi/fstbin/fstaddselfloops.cc diff --git a/speechx/speechx/kaldi/fstbin/fstdeterminizestar.cc b/runtime/engine/kaldi/fstbin/fstdeterminizestar.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstdeterminizestar.cc rename to runtime/engine/kaldi/fstbin/fstdeterminizestar.cc diff --git a/speechx/speechx/kaldi/fstbin/fstisstochastic.cc b/runtime/engine/kaldi/fstbin/fstisstochastic.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstisstochastic.cc rename to runtime/engine/kaldi/fstbin/fstisstochastic.cc diff --git a/speechx/speechx/kaldi/fstbin/fstminimizeencoded.cc b/runtime/engine/kaldi/fstbin/fstminimizeencoded.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fstminimizeencoded.cc rename to runtime/engine/kaldi/fstbin/fstminimizeencoded.cc diff --git a/speechx/speechx/kaldi/fstbin/fsttablecompose.cc b/runtime/engine/kaldi/fstbin/fsttablecompose.cc similarity index 100% rename from speechx/speechx/kaldi/fstbin/fsttablecompose.cc rename to runtime/engine/kaldi/fstbin/fsttablecompose.cc diff --git a/speechx/speechx/kaldi/fstext/CMakeLists.txt b/runtime/engine/kaldi/fstext/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/fstext/CMakeLists.txt rename to runtime/engine/kaldi/fstext/CMakeLists.txt diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice-inl.h b/runtime/engine/kaldi/fstext/determinize-lattice-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-lattice-inl.h rename to runtime/engine/kaldi/fstext/determinize-lattice-inl.h diff --git a/speechx/speechx/kaldi/fstext/determinize-lattice.h b/runtime/engine/kaldi/fstext/determinize-lattice.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-lattice.h rename to runtime/engine/kaldi/fstext/determinize-lattice.h diff --git a/speechx/speechx/kaldi/fstext/determinize-star-inl.h 
b/runtime/engine/kaldi/fstext/determinize-star-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-star-inl.h rename to runtime/engine/kaldi/fstext/determinize-star-inl.h diff --git a/speechx/speechx/kaldi/fstext/determinize-star.h b/runtime/engine/kaldi/fstext/determinize-star.h similarity index 100% rename from speechx/speechx/kaldi/fstext/determinize-star.h rename to runtime/engine/kaldi/fstext/determinize-star.h diff --git a/speechx/speechx/kaldi/fstext/fstext-lib.h b/runtime/engine/kaldi/fstext/fstext-lib.h similarity index 100% rename from speechx/speechx/kaldi/fstext/fstext-lib.h rename to runtime/engine/kaldi/fstext/fstext-lib.h diff --git a/speechx/speechx/kaldi/fstext/fstext-utils-inl.h b/runtime/engine/kaldi/fstext/fstext-utils-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/fstext-utils-inl.h rename to runtime/engine/kaldi/fstext/fstext-utils-inl.h diff --git a/speechx/speechx/kaldi/fstext/fstext-utils.h b/runtime/engine/kaldi/fstext/fstext-utils.h similarity index 100% rename from speechx/speechx/kaldi/fstext/fstext-utils.h rename to runtime/engine/kaldi/fstext/fstext-utils.h diff --git a/speechx/speechx/kaldi/fstext/kaldi-fst-io-inl.h b/runtime/engine/kaldi/fstext/kaldi-fst-io-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/kaldi-fst-io-inl.h rename to runtime/engine/kaldi/fstext/kaldi-fst-io-inl.h diff --git a/speechx/speechx/kaldi/fstext/kaldi-fst-io.cc b/runtime/engine/kaldi/fstext/kaldi-fst-io.cc similarity index 100% rename from speechx/speechx/kaldi/fstext/kaldi-fst-io.cc rename to runtime/engine/kaldi/fstext/kaldi-fst-io.cc diff --git a/speechx/speechx/kaldi/fstext/kaldi-fst-io.h b/runtime/engine/kaldi/fstext/kaldi-fst-io.h similarity index 100% rename from speechx/speechx/kaldi/fstext/kaldi-fst-io.h rename to runtime/engine/kaldi/fstext/kaldi-fst-io.h diff --git a/speechx/speechx/kaldi/fstext/lattice-utils-inl.h b/runtime/engine/kaldi/fstext/lattice-utils-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/lattice-utils-inl.h rename to runtime/engine/kaldi/fstext/lattice-utils-inl.h diff --git a/speechx/speechx/kaldi/fstext/lattice-utils.h b/runtime/engine/kaldi/fstext/lattice-utils.h similarity index 100% rename from speechx/speechx/kaldi/fstext/lattice-utils.h rename to runtime/engine/kaldi/fstext/lattice-utils.h diff --git a/speechx/speechx/kaldi/fstext/lattice-weight.h b/runtime/engine/kaldi/fstext/lattice-weight.h similarity index 100% rename from speechx/speechx/kaldi/fstext/lattice-weight.h rename to runtime/engine/kaldi/fstext/lattice-weight.h diff --git a/speechx/speechx/kaldi/fstext/pre-determinize-inl.h b/runtime/engine/kaldi/fstext/pre-determinize-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/pre-determinize-inl.h rename to runtime/engine/kaldi/fstext/pre-determinize-inl.h diff --git a/speechx/speechx/kaldi/fstext/pre-determinize.h b/runtime/engine/kaldi/fstext/pre-determinize.h similarity index 100% rename from speechx/speechx/kaldi/fstext/pre-determinize.h rename to runtime/engine/kaldi/fstext/pre-determinize.h diff --git a/speechx/speechx/kaldi/fstext/remove-eps-local-inl.h b/runtime/engine/kaldi/fstext/remove-eps-local-inl.h similarity index 100% rename from speechx/speechx/kaldi/fstext/remove-eps-local-inl.h rename to runtime/engine/kaldi/fstext/remove-eps-local-inl.h diff --git a/speechx/speechx/kaldi/fstext/remove-eps-local.h b/runtime/engine/kaldi/fstext/remove-eps-local.h similarity index 100% rename from 
speechx/speechx/kaldi/fstext/remove-eps-local.h rename to runtime/engine/kaldi/fstext/remove-eps-local.h diff --git a/speechx/speechx/kaldi/fstext/table-matcher.h b/runtime/engine/kaldi/fstext/table-matcher.h similarity index 100% rename from speechx/speechx/kaldi/fstext/table-matcher.h rename to runtime/engine/kaldi/fstext/table-matcher.h diff --git a/speechx/speechx/kaldi/lat/CMakeLists.txt b/runtime/engine/kaldi/lat/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/lat/CMakeLists.txt rename to runtime/engine/kaldi/lat/CMakeLists.txt diff --git a/speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc b/runtime/engine/kaldi/lat/determinize-lattice-pruned.cc similarity index 100% rename from speechx/speechx/kaldi/lat/determinize-lattice-pruned.cc rename to runtime/engine/kaldi/lat/determinize-lattice-pruned.cc diff --git a/speechx/speechx/kaldi/lat/determinize-lattice-pruned.h b/runtime/engine/kaldi/lat/determinize-lattice-pruned.h similarity index 100% rename from speechx/speechx/kaldi/lat/determinize-lattice-pruned.h rename to runtime/engine/kaldi/lat/determinize-lattice-pruned.h diff --git a/speechx/speechx/kaldi/lat/kaldi-lattice.cc b/runtime/engine/kaldi/lat/kaldi-lattice.cc similarity index 100% rename from speechx/speechx/kaldi/lat/kaldi-lattice.cc rename to runtime/engine/kaldi/lat/kaldi-lattice.cc diff --git a/speechx/speechx/kaldi/lat/kaldi-lattice.h b/runtime/engine/kaldi/lat/kaldi-lattice.h similarity index 100% rename from speechx/speechx/kaldi/lat/kaldi-lattice.h rename to runtime/engine/kaldi/lat/kaldi-lattice.h diff --git a/speechx/speechx/kaldi/lat/lattice-functions.cc b/runtime/engine/kaldi/lat/lattice-functions.cc similarity index 100% rename from speechx/speechx/kaldi/lat/lattice-functions.cc rename to runtime/engine/kaldi/lat/lattice-functions.cc diff --git a/speechx/speechx/kaldi/lat/lattice-functions.h b/runtime/engine/kaldi/lat/lattice-functions.h similarity index 97% rename from speechx/speechx/kaldi/lat/lattice-functions.h rename to runtime/engine/kaldi/lat/lattice-functions.h index 6b1b6656..785d3f96 100644 --- a/speechx/speechx/kaldi/lat/lattice-functions.h +++ b/runtime/engine/kaldi/lat/lattice-functions.h @@ -355,12 +355,12 @@ bool PruneLattice(BaseFloat beam, LatticeType *lat); // // // /// This function returns the number of words in the longest sentence in a -// /// CompactLattice (i.e. the the maximum of any path, of the count of +// /// CompactLattice (i.e. the maximum of any path, of the count of // /// olabels on that path). // int32 LongestSentenceLength(const Lattice &lat); // // /// This function returns the number of words in the longest sentence in a -// /// CompactLattice, i.e. the the maximum of any path, of the count of +// /// CompactLattice, i.e. the maximum of any path, of the count of // /// labels on that path... note, in CompactLattice, the ilabels and olabels // /// are identical because it is an acceptor. // int32 LongestSentenceLength(const CompactLattice &lat); @@ -408,7 +408,7 @@ bool PruneLattice(BaseFloat beam, LatticeType *lat); // // /// This function computes the mapping from the pair // /// (frame-index, transition-id) to the pair -// /// (sum-of-acoustic-scores, num-of-occurences) over all occurences of the +// /// (sum-of-acoustic-scores, num-of-occurrences) over all occurrences of the // /// transition-id in that frame. // /// frame-index in the lattice. 
// /// This function is useful for retaining the acoustic scores in a @@ -422,13 +422,13 @@ bool PruneLattice(BaseFloat beam, LatticeType *lat); // /// @param [out] acoustic_scores // /// Pointer to a map from the pair (frame-index, // /// transition-id) to a pair (sum-of-acoustic-scores, -// /// num-of-occurences). +// /// num-of-occurrences). // /// Usually the acoustic scores for a pdf-id (and hence // /// transition-id) on a frame will be the same for all the -// /// occurences of the pdf-id in that frame. +// /// occurrences of the pdf-id in that frame. // /// But if not, we will take the average of the acoustic // /// scores. Hence, we store both the sum-of-acoustic-scores -// /// and the num-of-occurences of the transition-id in that +// /// and the num-of-occurrences of the transition-id in that // /// frame. // void ComputeAcousticScoresMap( // const Lattice &lat, @@ -440,8 +440,8 @@ bool PruneLattice(BaseFloat beam, LatticeType *lat); // /// // /// @param [in] acoustic_scores // /// A map from the pair (frame-index, transition-id) to a -// /// pair (sum-of-acoustic-scores, num-of-occurences) of -// /// the occurences of the transition-id in that frame. +// /// pair (sum-of-acoustic-scores, num-of-occurrences) of +// /// the occurrences of the transition-id in that frame. // /// See the comments for ComputeAcousticScoresMap for // /// details. // /// @param [out] lat Pointer to the output lattice. diff --git a/speechx/speechx/kaldi/lm/CMakeLists.txt b/runtime/engine/kaldi/lm/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/lm/CMakeLists.txt rename to runtime/engine/kaldi/lm/CMakeLists.txt diff --git a/speechx/speechx/kaldi/lm/arpa-file-parser.cc b/runtime/engine/kaldi/lm/arpa-file-parser.cc similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-file-parser.cc rename to runtime/engine/kaldi/lm/arpa-file-parser.cc diff --git a/speechx/speechx/kaldi/lm/arpa-file-parser.h b/runtime/engine/kaldi/lm/arpa-file-parser.h similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-file-parser.h rename to runtime/engine/kaldi/lm/arpa-file-parser.h diff --git a/speechx/speechx/kaldi/lm/arpa-lm-compiler.cc b/runtime/engine/kaldi/lm/arpa-lm-compiler.cc similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-lm-compiler.cc rename to runtime/engine/kaldi/lm/arpa-lm-compiler.cc diff --git a/speechx/speechx/kaldi/lm/arpa-lm-compiler.h b/runtime/engine/kaldi/lm/arpa-lm-compiler.h similarity index 100% rename from speechx/speechx/kaldi/lm/arpa-lm-compiler.h rename to runtime/engine/kaldi/lm/arpa-lm-compiler.h diff --git a/speechx/speechx/kaldi/lmbin/CMakeLists.txt b/runtime/engine/kaldi/lmbin/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/lmbin/CMakeLists.txt rename to runtime/engine/kaldi/lmbin/CMakeLists.txt diff --git a/speechx/speechx/kaldi/lmbin/arpa2fst.cc b/runtime/engine/kaldi/lmbin/arpa2fst.cc similarity index 100% rename from speechx/speechx/kaldi/lmbin/arpa2fst.cc rename to runtime/engine/kaldi/lmbin/arpa2fst.cc diff --git a/speechx/speechx/kaldi/util/CMakeLists.txt b/runtime/engine/kaldi/util/CMakeLists.txt similarity index 100% rename from speechx/speechx/kaldi/util/CMakeLists.txt rename to runtime/engine/kaldi/util/CMakeLists.txt diff --git a/speechx/speechx/kaldi/util/basic-filebuf.h b/runtime/engine/kaldi/util/basic-filebuf.h similarity index 100% rename from speechx/speechx/kaldi/util/basic-filebuf.h rename to runtime/engine/kaldi/util/basic-filebuf.h diff --git a/speechx/speechx/kaldi/util/common-utils.h 
b/runtime/engine/kaldi/util/common-utils.h similarity index 100% rename from speechx/speechx/kaldi/util/common-utils.h rename to runtime/engine/kaldi/util/common-utils.h diff --git a/speechx/speechx/kaldi/util/const-integer-set-inl.h b/runtime/engine/kaldi/util/const-integer-set-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/const-integer-set-inl.h rename to runtime/engine/kaldi/util/const-integer-set-inl.h diff --git a/speechx/speechx/kaldi/util/const-integer-set.h b/runtime/engine/kaldi/util/const-integer-set.h similarity index 100% rename from speechx/speechx/kaldi/util/const-integer-set.h rename to runtime/engine/kaldi/util/const-integer-set.h diff --git a/speechx/speechx/kaldi/util/edit-distance-inl.h b/runtime/engine/kaldi/util/edit-distance-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/edit-distance-inl.h rename to runtime/engine/kaldi/util/edit-distance-inl.h diff --git a/speechx/speechx/kaldi/util/edit-distance.h b/runtime/engine/kaldi/util/edit-distance.h similarity index 100% rename from speechx/speechx/kaldi/util/edit-distance.h rename to runtime/engine/kaldi/util/edit-distance.h diff --git a/speechx/speechx/kaldi/util/hash-list-inl.h b/runtime/engine/kaldi/util/hash-list-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/hash-list-inl.h rename to runtime/engine/kaldi/util/hash-list-inl.h diff --git a/speechx/speechx/kaldi/util/hash-list.h b/runtime/engine/kaldi/util/hash-list.h similarity index 100% rename from speechx/speechx/kaldi/util/hash-list.h rename to runtime/engine/kaldi/util/hash-list.h diff --git a/speechx/speechx/kaldi/util/kaldi-cygwin-io-inl.h b/runtime/engine/kaldi/util/kaldi-cygwin-io-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-cygwin-io-inl.h rename to runtime/engine/kaldi/util/kaldi-cygwin-io-inl.h diff --git a/speechx/speechx/kaldi/util/kaldi-holder-inl.h b/runtime/engine/kaldi/util/kaldi-holder-inl.h similarity index 85% rename from speechx/speechx/kaldi/util/kaldi-holder-inl.h rename to runtime/engine/kaldi/util/kaldi-holder-inl.h index 134cdd93..9b441ad4 100644 --- a/speechx/speechx/kaldi/util/kaldi-holder-inl.h +++ b/runtime/engine/kaldi/util/kaldi-holder-inl.h @@ -754,53 +754,53 @@ class TokenVectorHolder { }; -class HtkMatrixHolder { - public: - typedef std::pair, HtkHeader> T; - - HtkMatrixHolder() {} - - static bool Write(std::ostream &os, bool binary, const T &t) { - if (!binary) - KALDI_ERR << "Non-binary HTK-format write not supported."; - bool ans = WriteHtk(os, t.first, t.second); - if (!ans) - KALDI_WARN << "Error detected writing HTK-format matrix."; - return ans; - } - - void Clear() { t_.first.Resize(0, 0); } - - // Reads into the holder. - bool Read(std::istream &is) { - bool ans = ReadHtk(is, &t_.first, &t_.second); - if (!ans) { - KALDI_WARN << "Error detected reading HTK-format matrix."; - return false; - } - return ans; - } - - // HTK-format matrices only read in binary. - static bool IsReadInBinary() { return true; } - - T &Value() { return t_; } - - void Swap(HtkMatrixHolder *other) { - t_.first.Swap(&(other->t_.first)); - std::swap(t_.second, other->t_.second); - } - - bool ExtractRange(const HtkMatrixHolder &other, - const std::string &range) { - KALDI_ERR << "ExtractRange is not defined for this type of holder."; - return false; - } - // Default destructor. 
- private: - KALDI_DISALLOW_COPY_AND_ASSIGN(HtkMatrixHolder); - T t_; -}; +//class HtkMatrixHolder { + //public: + //typedef std::pair, HtkHeader> T; + + //HtkMatrixHolder() {} + + //static bool Write(std::ostream &os, bool binary, const T &t) { + //if (!binary) + //KALDI_ERR << "Non-binary HTK-format write not supported."; + //bool ans = WriteHtk(os, t.first, t.second); + //if (!ans) + //KALDI_WARN << "Error detected writing HTK-format matrix."; + //return ans; + //} + + //void Clear() { t_.first.Resize(0, 0); } + + //// Reads into the holder. + //bool Read(std::istream &is) { + //bool ans = ReadHtk(is, &t_.first, &t_.second); + //if (!ans) { + //KALDI_WARN << "Error detected reading HTK-format matrix."; + //return false; + //} + //return ans; + //} + + //// HTK-format matrices only read in binary. + //static bool IsReadInBinary() { return true; } + + //T &Value() { return t_; } + + //void Swap(HtkMatrixHolder *other) { + //t_.first.Swap(&(other->t_.first)); + //std::swap(t_.second, other->t_.second); + //} + + //bool ExtractRange(const HtkMatrixHolder &other, + //const std::string &range) { + //KALDI_ERR << "ExtractRange is not defined for this type of holder."; + //return false; + //} + //// Default destructor. + //private: + //KALDI_DISALLOW_COPY_AND_ASSIGN(HtkMatrixHolder); + //T t_; +//}; // SphinxMatrixHolder can be used to read and write feature files in // CMU Sphinx format. 13-dimensional big-endian features are assumed. @@ -813,104 +813,104 @@ class HtkMatrixHolder { // be no problem, because the usage help of Sphinx' "wave2feat" for example // says that Sphinx features are always big endian. // Note: the kFeatDim defaults to 13, see forward declaration in kaldi-holder.h -template class SphinxMatrixHolder { - public: - typedef Matrix T; - - SphinxMatrixHolder() {} - - void Clear() { feats_.Resize(0, 0); } - - // Writes Sphinx-format features - static bool Write(std::ostream &os, bool binary, const T &m) { - if (!binary) { - KALDI_WARN << "SphinxMatrixHolder can't write Sphinx features in text "; - return false; - } - - int32 size = m.NumRows() * m.NumCols(); - if (MachineIsLittleEndian()) - KALDI_SWAP4(size); - // write the header - os.write(reinterpret_cast (&size), sizeof(size)); - - for (MatrixIndexT i = 0; i < m.NumRows(); i++) { - std::vector tmp(m.NumCols()); - for (MatrixIndexT j = 0; j < m.NumCols(); j++) { - tmp[j] = static_cast(m(i, j)); - if (MachineIsLittleEndian()) - KALDI_SWAP4(tmp[j]); - } - os.write(reinterpret_cast(&(tmp[0])), - tmp.size() * 4); - } - return true; - } - - // Reads the features into a Kaldi Matrix - bool Read(std::istream &is) { - int32 nmfcc; - - is.read(reinterpret_cast (&nmfcc), sizeof(nmfcc)); - if (MachineIsLittleEndian()) - KALDI_SWAP4(nmfcc); - KALDI_VLOG(2) << "#feats: " << nmfcc; - int32 nfvec = nmfcc / kFeatDim; - if ((nmfcc % kFeatDim) != 0) { - KALDI_WARN << "Sphinx feature count is inconsistent with vector length "; - return false; - } - - feats_.Resize(nfvec, kFeatDim); - for (MatrixIndexT i = 0; i < feats_.NumRows(); i++) { - if (sizeof(BaseFloat) == sizeof(float32)) { - is.read(reinterpret_cast (feats_.RowData(i)), - kFeatDim * sizeof(float32)); - if (!is.good()) { - KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; - return false; - } - if (MachineIsLittleEndian()) { - for (MatrixIndexT j = 0; j < kFeatDim; j++) - KALDI_SWAP4(feats_(i, j)); - } - } else { // KALDI_DOUBLEPRECISION=1 - float32 tmp[kFeatDim]; - is.read(reinterpret_cast (tmp), sizeof(tmp)); - if (!is.good()) { - KALDI_WARN << "Unexpected error/EOF 
while reading Sphinx features "; - return false; - } - for (MatrixIndexT j = 0; j < kFeatDim; j++) { - if (MachineIsLittleEndian()) - KALDI_SWAP4(tmp[j]); - feats_(i, j) = static_cast(tmp[j]); - } - } - } - - return true; - } - - // Only read in binary - static bool IsReadInBinary() { return true; } - - T &Value() { return feats_; } - - void Swap(SphinxMatrixHolder *other) { - feats_.Swap(&(other->feats_)); - } - - bool ExtractRange(const SphinxMatrixHolder &other, - const std::string &range) { - KALDI_ERR << "ExtractRange is not defined for this type of holder."; - return false; - } - - private: - KALDI_DISALLOW_COPY_AND_ASSIGN(SphinxMatrixHolder); - T feats_; -}; +//template class SphinxMatrixHolder { + //public: + //typedef Matrix T; + + //SphinxMatrixHolder() {} + + //void Clear() { feats_.Resize(0, 0); } + + //// Writes Sphinx-format features + //static bool Write(std::ostream &os, bool binary, const T &m) { + //if (!binary) { + //KALDI_WARN << "SphinxMatrixHolder can't write Sphinx features in text "; + //return false; + //} + + //int32 size = m.NumRows() * m.NumCols(); + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(size); + //// write the header + //os.write(reinterpret_cast (&size), sizeof(size)); + + //for (MatrixIndexT i = 0; i < m.NumRows(); i++) { + //std::vector tmp(m.NumCols()); + //for (MatrixIndexT j = 0; j < m.NumCols(); j++) { + //tmp[j] = static_cast(m(i, j)); + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(tmp[j]); + //} + //os.write(reinterpret_cast(&(tmp[0])), + //tmp.size() * 4); + //} + //return true; + //} + + //// Reads the features into a Kaldi Matrix + //bool Read(std::istream &is) { + //int32 nmfcc; + + //is.read(reinterpret_cast (&nmfcc), sizeof(nmfcc)); + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(nmfcc); + //KALDI_VLOG(2) << "#feats: " << nmfcc; + //int32 nfvec = nmfcc / kFeatDim; + //if ((nmfcc % kFeatDim) != 0) { + //KALDI_WARN << "Sphinx feature count is inconsistent with vector length "; + //return false; + //} + + //feats_.Resize(nfvec, kFeatDim); + //for (MatrixIndexT i = 0; i < feats_.NumRows(); i++) { + //if (sizeof(BaseFloat) == sizeof(float32)) { + //is.read(reinterpret_cast (feats_.RowData(i)), + //kFeatDim * sizeof(float32)); + //if (!is.good()) { + //KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + //return false; + //} + //if (MachineIsLittleEndian()) { + //for (MatrixIndexT j = 0; j < kFeatDim; j++) + //KALDI_SWAP4(feats_(i, j)); + //} + //} else { // KALDI_DOUBLEPRECISION=1 + //float32 tmp[kFeatDim]; + //is.read(reinterpret_cast (tmp), sizeof(tmp)); + //if (!is.good()) { + //KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + //return false; + //} + //for (MatrixIndexT j = 0; j < kFeatDim; j++) { + //if (MachineIsLittleEndian()) + //KALDI_SWAP4(tmp[j]); + //feats_(i, j) = static_cast(tmp[j]); + //} + //} + //} + + //return true; + //} + + //// Only read in binary + //static bool IsReadInBinary() { return true; } + + //T &Value() { return feats_; } + + //void Swap(SphinxMatrixHolder *other) { + //feats_.Swap(&(other->feats_)); + //} + + //bool ExtractRange(const SphinxMatrixHolder &other, + //const std::string &range) { + //KALDI_ERR << "ExtractRange is not defined for this type of holder."; + //return false; + //} + + //private: + //KALDI_DISALLOW_COPY_AND_ASSIGN(SphinxMatrixHolder); + //T feats_; +//}; /// @} end "addtogroup holders" diff --git a/speechx/speechx/kaldi/util/kaldi-holder.cc b/runtime/engine/kaldi/util/kaldi-holder.cc similarity index 99% rename from 
speechx/speechx/kaldi/util/kaldi-holder.cc rename to runtime/engine/kaldi/util/kaldi-holder.cc index 577679ef..6b0eebb9 100644 --- a/speechx/speechx/kaldi/util/kaldi-holder.cc +++ b/runtime/engine/kaldi/util/kaldi-holder.cc @@ -85,7 +85,7 @@ bool ParseMatrixRangeSpecifier(const std::string &range, return status; } -bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, +/*bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, GeneralMatrix *output) { // We just inspect input's type and forward to the correct implementation // if available. For kSparseMatrix, we do just fairly inefficient conversion @@ -135,6 +135,7 @@ template bool ExtractObjectRange(const CompressedMatrix &, const std::string &, template bool ExtractObjectRange(const CompressedMatrix &, const std::string &, Matrix *); +*/ template bool ExtractObjectRange(const Matrix &input, const std::string &range, Matrix *output) { diff --git a/speechx/speechx/kaldi/util/kaldi-holder.h b/runtime/engine/kaldi/util/kaldi-holder.h similarity index 96% rename from speechx/speechx/kaldi/util/kaldi-holder.h rename to runtime/engine/kaldi/util/kaldi-holder.h index f495f27f..a8c42c9f 100644 --- a/speechx/speechx/kaldi/util/kaldi-holder.h +++ b/runtime/engine/kaldi/util/kaldi-holder.h @@ -27,7 +27,6 @@ #include "util/kaldi-io.h" #include "util/text-utils.h" #include "matrix/kaldi-vector.h" -#include "matrix/sparse-matrix.h" namespace kaldi { @@ -214,10 +213,10 @@ class TokenVectorHolder; /// A class for reading/writing HTK-format matrices. /// T == std::pair, HtkHeader> -class HtkMatrixHolder; +//class HtkMatrixHolder; /// A class for reading/writing Sphinx format matrices. -template class SphinxMatrixHolder; +//template class SphinxMatrixHolder; /// This templated function exists so that we can write .scp files with /// 'object ranges' specified: the canonical example is a [first:last] range @@ -249,15 +248,15 @@ bool ExtractObjectRange(const Vector &input, const std::string &range, Vector *output); /// GeneralMatrix is always of type BaseFloat -bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, - GeneralMatrix *output); +//bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, + // GeneralMatrix *output); /// CompressedMatrix is always of the type BaseFloat but it is more /// efficient to provide template as it uses CompressedMatrix's own /// conversion to Matrix -template -bool ExtractObjectRange(const CompressedMatrix &input, const std::string &range, - Matrix *output); +//template +//bool ExtractObjectRange(const CompressedMatrix &input, const std::string &range, + // Matrix *output); // In SequentialTableReaderScriptImpl and RandomAccessTableReaderScriptImpl, for // cases where the scp contained 'range specifiers' (things in square brackets diff --git a/speechx/speechx/kaldi/util/kaldi-io-inl.h b/runtime/engine/kaldi/util/kaldi-io-inl.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-io-inl.h rename to runtime/engine/kaldi/util/kaldi-io-inl.h diff --git a/speechx/speechx/kaldi/util/kaldi-io.cc b/runtime/engine/kaldi/util/kaldi-io.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-io.cc rename to runtime/engine/kaldi/util/kaldi-io.cc diff --git a/speechx/speechx/kaldi/util/kaldi-io.h b/runtime/engine/kaldi/util/kaldi-io.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-io.h rename to runtime/engine/kaldi/util/kaldi-io.h diff --git a/speechx/speechx/kaldi/util/kaldi-pipebuf.h 
b/runtime/engine/kaldi/util/kaldi-pipebuf.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-pipebuf.h rename to runtime/engine/kaldi/util/kaldi-pipebuf.h diff --git a/speechx/speechx/kaldi/util/kaldi-semaphore.cc b/runtime/engine/kaldi/util/kaldi-semaphore.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-semaphore.cc rename to runtime/engine/kaldi/util/kaldi-semaphore.cc diff --git a/speechx/speechx/kaldi/util/kaldi-semaphore.h b/runtime/engine/kaldi/util/kaldi-semaphore.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-semaphore.h rename to runtime/engine/kaldi/util/kaldi-semaphore.h diff --git a/speechx/speechx/kaldi/util/kaldi-table-inl.h b/runtime/engine/kaldi/util/kaldi-table-inl.h similarity index 99% rename from speechx/speechx/kaldi/util/kaldi-table-inl.h rename to runtime/engine/kaldi/util/kaldi-table-inl.h index 6aca2f13..175e2704 100644 --- a/speechx/speechx/kaldi/util/kaldi-table-inl.h +++ b/runtime/engine/kaldi/util/kaldi-table-inl.h @@ -1587,7 +1587,7 @@ template class RandomAccessTableReaderImplBase { // this from a pipe. In principle we could read it on-demand as for the // archives, but this would probably be overkill. -// Note: the code for this this class is similar to TableWriterScriptImpl: +// Note: the code for this class is similar to TableWriterScriptImpl: // try to keep them in sync. template class RandomAccessTableReaderScriptImpl: diff --git a/speechx/speechx/kaldi/util/kaldi-table.cc b/runtime/engine/kaldi/util/kaldi-table.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-table.cc rename to runtime/engine/kaldi/util/kaldi-table.cc diff --git a/speechx/speechx/kaldi/util/kaldi-table.h b/runtime/engine/kaldi/util/kaldi-table.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-table.h rename to runtime/engine/kaldi/util/kaldi-table.h diff --git a/speechx/speechx/kaldi/util/kaldi-thread.cc b/runtime/engine/kaldi/util/kaldi-thread.cc similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-thread.cc rename to runtime/engine/kaldi/util/kaldi-thread.cc diff --git a/speechx/speechx/kaldi/util/kaldi-thread.h b/runtime/engine/kaldi/util/kaldi-thread.h similarity index 100% rename from speechx/speechx/kaldi/util/kaldi-thread.h rename to runtime/engine/kaldi/util/kaldi-thread.h diff --git a/speechx/speechx/kaldi/util/options-itf.h b/runtime/engine/kaldi/util/options-itf.h similarity index 100% rename from speechx/speechx/kaldi/util/options-itf.h rename to runtime/engine/kaldi/util/options-itf.h diff --git a/speechx/speechx/kaldi/util/parse-options.cc b/runtime/engine/kaldi/util/parse-options.cc similarity index 100% rename from speechx/speechx/kaldi/util/parse-options.cc rename to runtime/engine/kaldi/util/parse-options.cc diff --git a/speechx/speechx/kaldi/util/parse-options.h b/runtime/engine/kaldi/util/parse-options.h similarity index 100% rename from speechx/speechx/kaldi/util/parse-options.h rename to runtime/engine/kaldi/util/parse-options.h diff --git a/speechx/speechx/kaldi/util/simple-io-funcs.cc b/runtime/engine/kaldi/util/simple-io-funcs.cc similarity index 100% rename from speechx/speechx/kaldi/util/simple-io-funcs.cc rename to runtime/engine/kaldi/util/simple-io-funcs.cc diff --git a/speechx/speechx/kaldi/util/simple-io-funcs.h b/runtime/engine/kaldi/util/simple-io-funcs.h similarity index 100% rename from speechx/speechx/kaldi/util/simple-io-funcs.h rename to runtime/engine/kaldi/util/simple-io-funcs.h diff --git 
a/speechx/speechx/kaldi/util/simple-options.cc b/runtime/engine/kaldi/util/simple-options.cc similarity index 100% rename from speechx/speechx/kaldi/util/simple-options.cc rename to runtime/engine/kaldi/util/simple-options.cc diff --git a/speechx/speechx/kaldi/util/simple-options.h b/runtime/engine/kaldi/util/simple-options.h similarity index 100% rename from speechx/speechx/kaldi/util/simple-options.h rename to runtime/engine/kaldi/util/simple-options.h diff --git a/speechx/speechx/kaldi/util/stl-utils.h b/runtime/engine/kaldi/util/stl-utils.h similarity index 100% rename from speechx/speechx/kaldi/util/stl-utils.h rename to runtime/engine/kaldi/util/stl-utils.h diff --git a/speechx/speechx/kaldi/util/table-types.h b/runtime/engine/kaldi/util/table-types.h similarity index 69% rename from speechx/speechx/kaldi/util/table-types.h rename to runtime/engine/kaldi/util/table-types.h index efcdf1b5..665a1327 100644 --- a/speechx/speechx/kaldi/util/table-types.h +++ b/runtime/engine/kaldi/util/table-types.h @@ -23,7 +23,8 @@ #include "base/kaldi-common.h" #include "util/kaldi-table.h" #include "util/kaldi-holder.h" -#include "matrix/matrix-lib.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/kaldi-vector.h" namespace kaldi { @@ -51,8 +52,8 @@ typedef RandomAccessTableReader > > typedef RandomAccessTableReaderMapped > > RandomAccessDoubleMatrixReaderMapped; -typedef TableWriter > - CompressedMatrixWriter; +//typedef TableWriter > + //CompressedMatrixWriter; typedef TableWriter > > BaseFloatVectorWriter; @@ -70,39 +71,39 @@ typedef SequentialTableReader > > typedef RandomAccessTableReader > > RandomAccessDoubleVectorReader; -typedef TableWriter > > - BaseFloatCuMatrixWriter; -typedef SequentialTableReader > > - SequentialBaseFloatCuMatrixReader; -typedef RandomAccessTableReader > > - RandomAccessBaseFloatCuMatrixReader; -typedef RandomAccessTableReaderMapped > > - RandomAccessBaseFloatCuMatrixReaderMapped; - -typedef TableWriter > > - DoubleCuMatrixWriter; -typedef SequentialTableReader > > - SequentialDoubleCuMatrixReader; -typedef RandomAccessTableReader > > - RandomAccessDoubleCuMatrixReader; -typedef RandomAccessTableReaderMapped > > - RandomAccessDoubleCuMatrixReaderMapped; - -typedef TableWriter > > - BaseFloatCuVectorWriter; -typedef SequentialTableReader > > - SequentialBaseFloatCuVectorReader; -typedef RandomAccessTableReader > > - RandomAccessBaseFloatCuVectorReader; -typedef RandomAccessTableReaderMapped > > - RandomAccessBaseFloatCuVectorReaderMapped; - -typedef TableWriter > > - DoubleCuVectorWriter; -typedef SequentialTableReader > > - SequentialDoubleCuVectorReader; -typedef RandomAccessTableReader > > - RandomAccessDoubleCuVectorReader; +//typedef TableWriter > > + //BaseFloatCuMatrixWriter; +//typedef SequentialTableReader > > + //SequentialBaseFloatCuMatrixReader; +//typedef RandomAccessTableReader > > + //RandomAccessBaseFloatCuMatrixReader; +//typedef RandomAccessTableReaderMapped > > + //RandomAccessBaseFloatCuMatrixReaderMapped; + +//typedef TableWriter > > + //DoubleCuMatrixWriter; +//typedef SequentialTableReader > > + //SequentialDoubleCuMatrixReader; +//typedef RandomAccessTableReader > > + //RandomAccessDoubleCuMatrixReader; +//typedef RandomAccessTableReaderMapped > > + //RandomAccessDoubleCuMatrixReaderMapped; + +//typedef TableWriter > > + //BaseFloatCuVectorWriter; +//typedef SequentialTableReader > > + //SequentialBaseFloatCuVectorReader; +//typedef RandomAccessTableReader > > + //RandomAccessBaseFloatCuVectorReader; +//typedef RandomAccessTableReaderMapped > 
> + //RandomAccessBaseFloatCuVectorReaderMapped; + +//typedef TableWriter > > + //DoubleCuVectorWriter; +//typedef SequentialTableReader > > + //SequentialDoubleCuVectorReader; +//typedef RandomAccessTableReader > > + //RandomAccessDoubleCuVectorReader; typedef TableWriter > Int32Writer; @@ -150,8 +151,6 @@ typedef TableWriter > BoolWriter; typedef SequentialTableReader > SequentialBoolReader; typedef RandomAccessTableReader > RandomAccessBoolReader; - - /// TokenWriter is a writer specialized for std::string where the strings /// are nonempty and whitespace-free. T == std::string typedef TableWriter TokenWriter; @@ -169,14 +168,14 @@ typedef RandomAccessTableReader RandomAccessTokenVectorReader; -typedef TableWriter > - GeneralMatrixWriter; -typedef SequentialTableReader > - SequentialGeneralMatrixReader; -typedef RandomAccessTableReader > - RandomAccessGeneralMatrixReader; -typedef RandomAccessTableReaderMapped > - RandomAccessGeneralMatrixReaderMapped; +//typedef TableWriter > +// GeneralMatrixWriter; +//typedef SequentialTableReader > + // SequentialGeneralMatrixReader; +//typedef RandomAccessTableReader > + // RandomAccessGeneralMatrixReader; +//typedef RandomAccessTableReaderMapped > + // RandomAccessGeneralMatrixReaderMapped; diff --git a/speechx/speechx/kaldi/util/text-utils.cc b/runtime/engine/kaldi/util/text-utils.cc similarity index 100% rename from speechx/speechx/kaldi/util/text-utils.cc rename to runtime/engine/kaldi/util/text-utils.cc diff --git a/speechx/speechx/kaldi/util/text-utils.h b/runtime/engine/kaldi/util/text-utils.h similarity index 100% rename from speechx/speechx/kaldi/util/text-utils.h rename to runtime/engine/kaldi/util/text-utils.h diff --git a/runtime/engine/vad/CMakeLists.txt b/runtime/engine/vad/CMakeLists.txt new file mode 100644 index 00000000..442acbd8 --- /dev/null +++ b/runtime/engine/vad/CMakeLists.txt @@ -0,0 +1,5 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +add_subdirectory(nnet) + +add_subdirectory(interface) \ No newline at end of file diff --git a/runtime/engine/vad/frontend/wav.h b/runtime/engine/vad/frontend/wav.h new file mode 100644 index 00000000..f9b7bee2 --- /dev/null +++ b/runtime/engine/vad/frontend/wav.h @@ -0,0 +1,199 @@ +// Copyright (c) 2016 Personal (Binbin Zhang) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
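Editor's note on the WAV frontend that follows: `WavReader::Open()` does `fread(&header, 1, sizeof(header), fp)` straight into the `WavHeader` struct, so the struct must pack to the canonical 44-byte RIFF/WAVE header with no compiler-inserted padding. On mainstream ABIs it does; a hedged compile-time guard (not part of the patch) could make that explicit:

```cpp
// Hypothetical guard, not in the patch: the raw fread() in WavReader::Open()
// is only correct if WavHeader matches the 44-byte canonical WAV header.
#include "vad/frontend/wav.h"

static_assert(sizeof(wav::WavHeader) == 44,
              "WavHeader must pack to the 44-byte canonical WAV header");
```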
+#pragma once +#include <assert.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include <iostream> +#include <string> + +namespace wav { + +struct WavHeader { + char riff[4]; // "riff" + unsigned int size; + char wav[4]; // "WAVE" + char fmt[4]; // "fmt " + unsigned int fmt_size; + uint16_t format; + uint16_t channels; + unsigned int sample_rate; + unsigned int bytes_per_second; + uint16_t block_size; + uint16_t bit; + char data[4]; // "data" + unsigned int data_size; +}; + +class WavReader { + public: + WavReader() : data_(nullptr) {} + explicit WavReader(const std::string& filename) { Open(filename); } + + bool Open(const std::string& filename) { + FILE* fp = fopen(filename.c_str(), "rb"); + if (NULL == fp) { + std::cout << "Error reading " << filename; + return false; + } + + WavHeader header; + fread(&header, 1, sizeof(header), fp); + if (header.fmt_size < 16) { + fprintf(stderr, + "WaveData: expect PCM format data " + "to have fmt chunk of at least size 16.\n"); + return false; + } else if (header.fmt_size > 16) { + int offset = 44 - 8 + header.fmt_size - 16; + fseek(fp, offset, SEEK_SET); + fread(header.data, 8, sizeof(char), fp); + } + // check "riff" "WAVE" "fmt " "data" + + // Skip any sub-chunks between "fmt" and "data". Usually there will + // be a single "fact" sub chunk, but on Windows there can also be a + // "list" sub chunk. + while (0 != strncmp(header.data, "data", 4)) { + // We will just ignore the data in these chunks. + fseek(fp, header.data_size, SEEK_CUR); + // read next sub chunk + fread(header.data, 8, sizeof(char), fp); + } + + num_channel_ = header.channels; + sample_rate_ = header.sample_rate; + bits_per_sample_ = header.bit; + int num_data = header.data_size / (bits_per_sample_ / 8); + data_ = new float[num_data]; // Create 1-dim array + num_samples_ = num_data / num_channel_; + + for (int i = 0; i < num_data; ++i) { + switch (bits_per_sample_) { + case 8: { + char sample; + fread(&sample, 1, sizeof(char), fp); + data_[i] = static_cast<float>(sample); + break; + } + case 16: { + int16_t sample; + fread(&sample, 1, sizeof(int16_t), fp); + data_[i] = static_cast<float>(sample); + break; + } + case 32: { + int sample; + fread(&sample, 1, sizeof(int), fp); + data_[i] = static_cast<float>(sample); + break; + } + default: + fprintf(stderr, "unsupported quantization bits\n"); + exit(1); + } + } + fclose(fp); + return true; + } + + int num_channel() const { return num_channel_; } + int sample_rate() const { return sample_rate_; } + int bits_per_sample() const { return bits_per_sample_; } + int num_samples() const { return num_samples_; } + const float* data() const { return data_; } + + private: + int num_channel_; + int sample_rate_; + int bits_per_sample_; + int num_samples_; // sample points per channel + float* data_; +}; + +class WavWriter { + public: + WavWriter(const float* data, + int num_samples, + int num_channel, + int sample_rate, + int bits_per_sample) + : data_(data), + num_samples_(num_samples), + num_channel_(num_channel), + sample_rate_(sample_rate), + bits_per_sample_(bits_per_sample) {} + + void Write(const std::string& filename) { + FILE* fp = fopen(filename.c_str(), "wb"); + // init char 'riff' 'WAVE' 'fmt ' 'data' + WavHeader header; + char wav_header[44] = { + 0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, + 0x45, 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x64, 0x61, 0x74, 0x61, 0x00, 0x00, 0x00, 0x00}; + memcpy(&header, wav_header, sizeof(header)); + header.channels = num_channel_; + header.bit = bits_per_sample_; + header.sample_rate = sample_rate_; + header.data_size = num_samples_ * num_channel_ * (bits_per_sample_ / 8); + header.size = sizeof(header) - 8 + header.data_size; + header.bytes_per_second = + sample_rate_ * num_channel_ * (bits_per_sample_ / 8); + header.block_size = num_channel_ * (bits_per_sample_ / 8); + + fwrite(&header, 1, sizeof(header), fp); + + for (int i = 0; i < num_samples_; ++i) { + for (int j = 0; j < num_channel_; ++j) { + switch (bits_per_sample_) { + case 8: { + char sample = + static_cast<char>(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + case 16: { + int16_t sample = + static_cast<int16_t>(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + case 32: { + int sample = + static_cast<int>(data_[i * num_channel_ + j]); + fwrite(&sample, 1, sizeof(sample), fp); + break; + } + } + } + } + fclose(fp); + } + + private: + const float* data_; + int num_samples_; // total float points in data_ + int num_channel_; + int sample_rate_; + int bits_per_sample_; +}; + +} // namespace wav diff --git a/runtime/engine/vad/interface/CMakeLists.txt b/runtime/engine/vad/interface/CMakeLists.txt new file mode 100644 index 00000000..e056ec39 --- /dev/null +++ b/runtime/engine/vad/interface/CMakeLists.txt @@ -0,0 +1,24 @@ +set(srcs + vad_interface.cc +) + +add_library(pps_vad_interface SHARED ${srcs}) +target_link_libraries(pps_vad_interface PUBLIC pps_vad extern_glog) + + +set(bin_name vad_interface_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_link_libraries(${bin_name} pps_vad_interface) +# set_target_properties(${bin_name} PROPERTIES PUBLIC_HEADER "vad_interface.h;../frontend/wav.h") + +file(RELATIVE_PATH DEST_DIR ${ENGINE_ROOT} ${CMAKE_CURRENT_SOURCE_DIR}) +install(TARGETS pps_vad_interface DESTINATION lib) +install(FILES vad_interface.h DESTINATION include/${DEST_DIR}) + +install(TARGETS vad_interface_main + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + PUBLIC_HEADER DESTINATION include/${DEST_DIR} +) +install(FILES vad_interface_main.cc DESTINATION demo/${DEST_DIR}) \ No newline at end of file diff --git a/runtime/engine/vad/interface/vad_interface.cc b/runtime/engine/vad/interface/vad_interface.cc new file mode 100644 index 00000000..2e5c9175 --- /dev/null +++ b/runtime/engine/vad/interface/vad_interface.cc @@ -0,0 +1,103 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
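Before the implementation: `PPSVadCreateInstance()` pulls all of its settings from a plain config file through `common/base/config.h`. A sketch of the expected keys, with the defaults taken from the `Read()` calls below; the on-disk key/value syntax of `Config` is not shown in this patch, so the whitespace-separated layout and the model file name here are assumptions:

```cpp
// Illustrative config contents for PPSVadCreateInstance() (editor's sketch).
const char* kExampleVadConf =
    "sr 16000\n"
    "frame_ms 32\n"
    "threshold 0.45\n"
    "beam 0.15\n"
    "min_silence_duration_ms 200\n"
    "speech_pad_left_ms 0\n"
    "speech_pad_right_ms 0\n"
    "model_path silero_vad.onnx\n"  // hypothetical file name
    "num_cpu_thread 1\n";
```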
+ + +#include "vad/interface/vad_interface.h" + +#include "common/base/config.h" +#include "vad/nnet/vad.h" + + +PPSHandle_t PPSVadCreateInstance(const char* conf_path) { + Config conf(conf_path); + ppspeech::VadNnetConf nnet_conf; + nnet_conf.sr = conf.Read("sr", 16000); + nnet_conf.frame_ms = conf.Read("frame_ms", 32); + nnet_conf.threshold = conf.Read("threshold", 0.45f); + nnet_conf.beam = conf.Read("beam", 0.15f); + nnet_conf.min_silence_duration_ms = + conf.Read("min_silence_duration_ms", 200); + nnet_conf.speech_pad_left_ms = conf.Read("speech_pad_left_ms", 0); + nnet_conf.speech_pad_right_ms = conf.Read("speech_pad_right_ms", 0); + + nnet_conf.model_file_path = conf.Read("model_path", std::string("")); + nnet_conf.param_file_path = conf.Read("param_path", std::string("")); + nnet_conf.num_cpu_thread = conf.Read("num_cpu_thread", 1); + + ppspeech::Vad* model = new ppspeech::Vad(nnet_conf.model_file_path); + + // custom config, but must be set before init + model->SetConfig(nnet_conf); + model->Init(); + + return static_cast<PPSHandle_t>(model); +} + + +int PPSVadDestroyInstance(PPSHandle_t instance) { + ppspeech::Vad* model = static_cast<ppspeech::Vad*>(instance); + if (model != nullptr) { + delete model; + model = nullptr; + } + return 0; +} + +int PPSVadChunkSizeSamples(PPSHandle_t instance) { + ppspeech::Vad* model = static_cast<ppspeech::Vad*>(instance); + if (model == nullptr) { + printf("instance is null\n"); + return -1; + } + + return model->WindowSizeSamples(); +} + +PPSVadState_t PPSVadFeedForward(PPSHandle_t instance, + float* chunk, + int num_element) { + ppspeech::Vad* model = static_cast<ppspeech::Vad*>(instance); + if (model == nullptr) { + printf("instance is null\n"); + return PPS_VAD_ILLEGAL; + } + + std::vector<float> chunk_in(chunk, chunk + num_element); + if (!model->ForwardChunk(chunk_in)) { + printf("forward chunk failed\n"); + return PPS_VAD_ILLEGAL; + } + ppspeech::Vad::State s = model->Postprocess(); + PPSVadState_t ret = static_cast<PPSVadState_t>(s); + return ret; +} + +int PPSVadReset(PPSHandle_t instance) { + ppspeech::Vad* model = static_cast<ppspeech::Vad*>(instance); + if (model == nullptr) { + printf("instance is null\n"); + return -1; + } + model->Reset(); + return 0; +} + +int PPSVadGetResult(PPSHandle_t instance, char* result, int max_len) { + ppspeech::Vad* model = static_cast<ppspeech::Vad*>(instance); + if (model == nullptr) { + printf("instance is null\n"); + return -1; + } + return model->GetResult(result, max_len); +} \ No newline at end of file diff --git a/runtime/engine/vad/interface/vad_interface.h b/runtime/engine/vad/interface/vad_interface.h new file mode 100644 index 00000000..15d0b811 --- /dev/null +++ b/runtime/engine/vad/interface/vad_interface.h @@ -0,0 +1,47 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
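One contract worth making explicit: `PPSVadFeedForward()` above converts `ppspeech::Vad::State` to the C enum with a plain cast, so the two enums in `vad.h` and the header below must stay value-aligned (ILLEGAL=0, SIL, START, SPEECH, END in both). A compile-time guard sketch (editor's addition, not in the patch) that could live next to the cast:

```cpp
#include "vad/interface/vad_interface.h"
#include "vad/nnet/vad.h"

// Both enum definitions appear in this patch; this only pins their alignment.
static_assert(
    static_cast<int>(ppspeech::Vad::State::ILLEGAL) == PPS_VAD_ILLEGAL &&
    static_cast<int>(ppspeech::Vad::State::SIL) == PPS_VAD_SIL &&
    static_cast<int>(ppspeech::Vad::State::START) == PPS_VAD_START &&
    static_cast<int>(ppspeech::Vad::State::SPEECH) == PPS_VAD_SPEECH &&
    static_cast<int>(ppspeech::Vad::State::END) == PPS_VAD_END,
    "PPSVadState_t must mirror ppspeech::Vad::State");
```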
+ +#pragma once + +#ifdef __cplusplus extern "C" { +#endif + +typedef void* PPSHandle_t; + +typedef enum { + PPS_VAD_ILLEGAL = 0, // error + PPS_VAD_SIL, // silence + PPS_VAD_START, // start speech + PPS_VAD_SPEECH, // in speech + PPS_VAD_END, // end speech + PPS_VAD_NUMSTATES, // number of states +} PPSVadState_t; + +PPSHandle_t PPSVadCreateInstance(const char* conf_path); + +int PPSVadDestroyInstance(PPSHandle_t instance); + +int PPSVadReset(PPSHandle_t instance); + +int PPSVadChunkSizeSamples(PPSHandle_t instance); + +PPSVadState_t PPSVadFeedForward(PPSHandle_t instance, + float* chunk, + int num_element); + +int PPSVadGetResult(PPSHandle_t instance, char* result, int max_len); +#ifdef __cplusplus +} +#endif // __cplusplus \ No newline at end of file diff --git a/runtime/engine/vad/interface/vad_interface_main.cc b/runtime/engine/vad/interface/vad_interface_main.cc new file mode 100644 index 00000000..6dba794d --- /dev/null +++ b/runtime/engine/vad/interface/vad_interface_main.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include +#include + +#include +#include "common/base/common.h" +#include "vad/frontend/wav.h" +#include "vad/interface/vad_interface.h" + +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cout << "Usage: vad_interface_main path/to/config wav.scp, " + "e.g. ./vad_interface_main config wav.scp" + << std::endl; + return -1; + } + + std::string config_path = argv[1]; + std::string wav_scp = argv[2]; + + PPSHandle_t handle = PPSVadCreateInstance(config_path.c_str()); + + std::ifstream fp_wav(wav_scp); + std::string line = ""; + while (getline(fp_wav, line)) { + std::vector<float> inputWav; // [0, 1] + wav::WavReader wav_reader = wav::WavReader(line); + auto sr = wav_reader.sample_rate(); + CHECK(sr == 16000) << " sr is " << sr << " expect 16000"; + + auto num_samples = wav_reader.num_samples(); + inputWav.resize(num_samples); + for (int i = 0; i < num_samples; i++) { + inputWav[i] = wav_reader.data()[i] / 32768; + } + + ppspeech::Timer timer; + int window_size_samples = PPSVadChunkSizeSamples(handle); + for (int64_t j = 0; j < num_samples; j += window_size_samples) { + auto start = j; + auto end = start + window_size_samples >= num_samples ? num_samples : start + window_size_samples; + std::vector<float> r(window_size_samples, 0); + auto current_chunk_size = end - start; + memcpy(r.data(), inputWav.data() + start, current_chunk_size * sizeof(float)); + + PPSVadState_t s = PPSVadFeedForward(handle, r.data(), r.size()); + } + + std::cout << "RTF=" << timer.Elapsed() / (double(num_samples) / sr) + << std::endl; + + char result[10240] = {0}; + PPSVadGetResult(handle, result, 10240); + std::cout << line << " " << result << std::endl; + + PPSVadReset(handle); + // getchar(); + } + PPSVadDestroyInstance(handle); + return 0; +} diff --git a/runtime/engine/vad/nnet/CMakeLists.txt b/runtime/engine/vad/nnet/CMakeLists.txt new file mode 100644 index 00000000..3ca951d9 --- /dev/null +++ b/runtime/engine/vad/nnet/CMakeLists.txt @@ -0,0 +1,19 @@ +set(srcs + vad.cc +) + +add_library(pps_vad ${srcs}) +target_link_libraries(pps_vad PUBLIC ${FASTDEPLOY_LIBS} common extern_glog) + + +set(bin_name vad_nnet_main) +add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc) +target_link_libraries(${bin_name} pps_vad) + +file(RELATIVE_PATH DEST_DIR ${ENGINE_ROOT} ${CMAKE_CURRENT_SOURCE_DIR}) +install(TARGETS pps_vad DESTINATION lib) +if(ANDROID) + install(TARGETS extern_glog DESTINATION lib) +else() # UNIX + install(TARGETS glog DESTINATION lib) +endif() diff --git a/runtime/engine/vad/nnet/vad.cc b/runtime/engine/vad/nnet/vad.cc new file mode 100644 index 00000000..101f2370 --- /dev/null +++ b/runtime/engine/vad/nnet/vad.cc @@ -0,0 +1,333 @@ +// Copyright (c) 2023 Chen Qianhe Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
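For orientation before the implementation: `Initialize()` and `ForwardChunk()` below bind the four inputs and three outputs of a Silero-style VAD ONNX graph. A self-contained sketch of that tensor contract (the `SileroIo` helper is illustrative, not part of the patch; names and shapes come from the code: "input" [1, N], "sr" [1], "h"/"c" [2, 1, 64]):

```cpp
#include <cstdint>
#include <vector>

// Illustrative mirror of the buffers Vad owns; 2 * 1 * 64 is the fixed LSTM
// state size the model expects (size_hc_ in vad.h).
struct SileroIo {
    std::vector<float> chunk;        // "input": window_size_samples floats in [-1, 1]
    std::vector<int64_t> sr{16000};  // "sr": scalar sample rate
    std::vector<float> h = std::vector<float>(2 * 1 * 64, 0.0f);  // "h" state
    std::vector<float> c = std::vector<float>(2 * 1 * 64, 0.0f);  // "c" state
    // outputs come back as: speech probability, updated h, updated c
};
```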
+#include "vad/nnet/vad.h" + +#include +#include + +#include "common/base/common.h" + + +namespace ppspeech { + +Vad::Vad(const std::string& model_file, + const fastdeploy::RuntimeOption& + custom_option /* = fastdeploy::RuntimeOption() */) { + valid_cpu_backends = {fastdeploy::Backend::ORT, + fastdeploy::Backend::OPENVINO}; + valid_gpu_backends = {fastdeploy::Backend::ORT, fastdeploy::Backend::TRT}; + + runtime_option = custom_option; + // ORT backend + runtime_option.UseCpu(); + runtime_option.UseOrtBackend(); + runtime_option.model_format = fastdeploy::ModelFormat::ONNX; + // grap opt level + runtime_option.ort_option.graph_optimization_level = 99; + // one-thread + runtime_option.ort_option.intra_op_num_threads = 1; + runtime_option.ort_option.inter_op_num_threads = 1; + // model path + runtime_option.model_file = model_file; +} + +void Vad::Init() { + std::lock_guard lock(init_lock_); + Initialize(); +} + +std::string Vad::ModelName() const { return "VAD"; } + +void Vad::SetConfig(const VadNnetConf conf) { + SetConfig(conf.sr, + conf.frame_ms, + conf.threshold, + conf.beam, + conf.min_silence_duration_ms, + conf.speech_pad_left_ms, + conf.speech_pad_right_ms); +} + +void Vad::SetConfig(const int& sr, + const int& frame_ms, + const float& threshold, + const float& beam, + const int& min_silence_duration_ms, + const int& speech_pad_left_ms, + const int& speech_pad_right_ms) { + if (initialized_) { + fastdeploy::FDERROR << "SetConfig must be called before init" + << std::endl; + throw std::runtime_error("SetConfig must be called before init"); + } + sample_rate_ = sr; + sr_per_ms_ = sr / 1000; + threshold_ = threshold; + beam_ = beam; + frame_ms_ = frame_ms; + min_silence_samples_ = min_silence_duration_ms * sr_per_ms_; + speech_pad_left_samples_ = speech_pad_left_ms * sr_per_ms_; + speech_pad_right_samples_ = speech_pad_right_ms * sr_per_ms_; + + // init chunk size + window_size_samples_ = frame_ms * sr_per_ms_; + current_chunk_size_ = window_size_samples_; + + fastdeploy::FDINFO << "sr=" << sr_per_ms_ << " threshold=" << threshold_ + << " beam=" << beam_ << " frame_ms=" << frame_ms_ + << " min_silence_duration_ms=" << min_silence_duration_ms + << " speech_pad_left_ms=" << speech_pad_left_ms + << " speech_pad_right_ms=" << speech_pad_right_ms; +} + +void Vad::Reset() { + std::memset(h_.data(), 0.0f, h_.size() * sizeof(float)); + std::memset(c_.data(), 0.0f, c_.size() * sizeof(float)); + + triggerd_ = false; + temp_end_ = 0; + current_sample_ = 0; + + speechStart_.clear(); + speechEnd_.clear(); + + states_.clear(); +} + +bool Vad::Initialize() { + // input & output holder + inputTensors_.resize(4); + outputTensors_.resize(3); + + // input shape + input_node_dims_.emplace_back(1); + input_node_dims_.emplace_back(window_size_samples_); + // sr buffer + sr_.resize(1); + sr_[0] = sample_rate_; + // hidden state buffer + h_.resize(size_hc_); + c_.resize(size_hc_); + + Reset(); + + + // InitRuntime + if (!InitRuntime()) { + fastdeploy::FDERROR << "Failed to initialize fastdeploy backend." 
+ << std::endl; + return false; + } + + initialized_ = true; + + + fastdeploy::FDINFO << "init done."; + return true; +} + +bool Vad::ForwardChunk(std::vector& chunk) { + // last chunk may not be window_size_samples_ + input_node_dims_.back() = chunk.size(); + assert(window_size_samples_ >= chunk.size()); + current_chunk_size_ = chunk.size(); + + inputTensors_[0].name = "input"; + inputTensors_[0].SetExternalData( + input_node_dims_, fastdeploy::FDDataType::FP32, chunk.data()); + inputTensors_[1].name = "sr"; + inputTensors_[1].SetExternalData( + sr_node_dims_, fastdeploy::FDDataType::INT64, sr_.data()); + inputTensors_[2].name = "h"; + inputTensors_[2].SetExternalData( + hc_node_dims_, fastdeploy::FDDataType::FP32, h_.data()); + inputTensors_[3].name = "c"; + inputTensors_[3].SetExternalData( + hc_node_dims_, fastdeploy::FDDataType::FP32, c_.data()); + + if (!Infer(inputTensors_, &outputTensors_)) { + return false; + } + + // Push forward sample index + current_sample_ += current_chunk_size_; + return true; +} + +const Vad::State& Vad::Postprocess() { + // update prob, h, c + outputProb_ = *(float*)outputTensors_[0].Data(); + auto* hn = static_cast(outputTensors_[1].MutableData()); + std::memcpy(h_.data(), hn, h_.size() * sizeof(float)); + auto* cn = static_cast(outputTensors_[2].MutableData()); + std::memcpy(c_.data(), cn, c_.size() * sizeof(float)); + + if (outputProb_ < threshold_ && !triggerd_) { + // 1. Silence +#ifdef PPS_DEBUG + DLOG(INFO) << "{ silence: " << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; +#endif + states_.emplace_back(Vad::State::SIL); + } else if (outputProb_ >= threshold_ && !triggerd_) { + // 2. Start + triggerd_ = true; + speech_start_ = + current_sample_ - current_chunk_size_ - speech_pad_left_samples_; + speech_start_ = std::max(int(speech_start_), 0); + float start_sec = 1.0 * speech_start_ / sample_rate_; + speechStart_.emplace_back(start_sec); +#ifdef PPS_DEBUG + DLOG(INFO) << "{ speech start: " << start_sec + << " s; prob: " << outputProb_ << " }"; +#endif + states_.emplace_back(Vad::State::START); + } else if (outputProb_ >= threshold_ - beam_ && triggerd_) { + // 3. Continue + + if (temp_end_ != 0) { + // speech prob relaxation, speech continues again +#ifdef PPS_DEBUG + DLOG(INFO) + << "{ speech fake end(sil < min_silence_ms) to continue: " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; +#endif + temp_end_ = 0; + } else { + // speech prob relaxation, keep tracking speech +#ifdef PPS_DEBUG + DLOG(INFO) << "{ speech continue: " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; +#endif + } + + states_.emplace_back(Vad::State::SPEECH); + } else if (outputProb_ < threshold_ - beam_ && triggerd_) { + // 4. End + if (temp_end_ == 0) { + temp_end_ = current_sample_; + } + + // check possible speech end + if (current_sample_ - temp_end_ < min_silence_samples_) { + // a. silence < min_slience_samples, continue speaking +#ifdef PPS_DEBUG + DLOG(INFO) << "{ speech fake end(sil < min_silence_ms): " + << 1.0 * current_sample_ / sample_rate_ + << " s; prob: " << outputProb_ << " }"; +#endif + states_.emplace_back(Vad::State::SIL); + } else { + // b. 
silence >= min_slience_samples, end speaking + speech_end_ = current_sample_ + speech_pad_right_samples_; + temp_end_ = 0; + triggerd_ = false; + auto end_sec = 1.0 * speech_end_ / sample_rate_; + speechEnd_.emplace_back(end_sec); +#ifdef PPS_DEBUG + DLOG(INFO) << "{ speech end: " << end_sec + << " s; prob: " << outputProb_ << " }"; +#endif + states_.emplace_back(Vad::State::END); + } + } + + return states_.back(); +} + +std::string Vad::ConvertTime(float time_s) const{ + float seconds_tmp, minutes_tmp, hours_tmp; + float seconds; + int minutes, hours; + + // 计算小时 + hours_tmp = time_s / 60 / 60; // 1 + hours = (int)hours_tmp; + + // 计算分钟 + minutes_tmp = time_s / 60; + if (minutes_tmp >= 60) { + minutes = minutes_tmp - 60 * (double)hours; + } + else { + minutes = minutes_tmp; + } + + // 计算秒数 + seconds_tmp = (60 * 60 * hours) + (60 * minutes); + seconds = time_s - seconds_tmp; + + // 输出格式 + std::stringstream ss; + ss << hours << ":" << minutes << ":" << seconds; + + return ss.str(); +} + +int Vad::GetResult(char* result, int max_len, + float removeThreshold, + float expandHeadThreshold, + float expandTailThreshold, + float mergeThreshold) const { + float audioLength = 1.0 * current_sample_ / sample_rate_; + if (speechStart_.empty() && speechEnd_.empty()) { + return {}; + } + if (speechEnd_.size() != speechStart_.size()) { + // set the audio length as the last end + speechEnd_.emplace_back(audioLength); + } + + std::string json = "["; + + for (int i = 0; i < speechStart_.size(); ++i) { + json += "{\"s\":\"" + ConvertTime(speechStart_[i]) + "\",\"e\":\"" + ConvertTime(speechEnd_[i]) + "\"},"; + } + json.pop_back(); + json += "]"; + + if(result != NULL){ + snprintf(result, max_len, "%s", json.c_str()); + } else { + DLOG(INFO) << "result is NULL"; + } + return 0; +} + +std::ostream& operator<<(std::ostream& os, const Vad::State& s) { + switch (s) { + case Vad::State::SIL: + os << "[SIL]"; + break; + case Vad::State::START: + os << "[STA]"; + break; + case Vad::State::SPEECH: + os << "[SPE]"; + break; + case Vad::State::END: + os << "[END]"; + break; + default: + // illegal state + os << "[ILL]"; + break; + } + return os; +} + +} // namespace ppspeech \ No newline at end of file diff --git a/runtime/engine/vad/nnet/vad.h b/runtime/engine/vad/nnet/vad.h new file mode 100644 index 00000000..31db78d2 --- /dev/null +++ b/runtime/engine/vad/nnet/vad.h @@ -0,0 +1,157 @@ +// Copyright (c) 2023 Chen Qianhe Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
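The four branches of `Postprocess()` above implement a simple hysteresis: speech starts once the probability reaches `threshold_`, and only ends after it drops below `threshold_ - beam_` and stays there for `min_silence_samples_`. A stripped-down sketch of just that decision (editor's addition; the padding and the min-silence wait are elided, as the comments note):

```cpp
// Distilled frame-level decision; the real code keeps emitting SIL after a
// drop below threshold - beam until the silence has lasted
// min_silence_samples_, and only then emits END.
enum class VadState { SIL, START, SPEECH, END };

VadState Step(float prob, float threshold, float beam, bool* triggered) {
    if (!*triggered) {
        if (prob < threshold) return VadState::SIL;  // 1. silence
        *triggered = true;                           // 2. speech starts
        return VadState::START;
    }
    if (prob >= threshold - beam) {                  // 3. speech continues
        return VadState::SPEECH;                     //    (relaxed lower bound)
    }
    *triggered = false;                              // 4. speech ends
    return VadState::END;
}
```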
+
+#pragma once
+#include <mutex>
+#include <string>
+#include <vector>
+
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/runtime.h"
+#include "vad/frontend/wav.h"
+
+namespace ppspeech {
+
+struct VadNnetConf {
+    // wav
+    int sr;
+    int frame_ms;
+    float threshold;
+    float beam;
+    int min_silence_duration_ms;
+    int speech_pad_left_ms;
+    int speech_pad_right_ms;
+
+    // model
+    std::string model_file_path;
+    std::string param_file_path;
+    std::string dict_file_path;
+    int num_cpu_thread;   // default 1 thread
+    std::string backend;  // ort, lite, etc.
+};
+
+class Vad : public fastdeploy::FastDeployModel {
+  public:
+    enum class State { ILLEGAL = 0, SIL, START, SPEECH, END };
+    friend std::ostream& operator<<(std::ostream& os, const Vad::State& s);
+
+    Vad(const std::string& model_file,
+        const fastdeploy::RuntimeOption& custom_option =
+            fastdeploy::RuntimeOption());
+
+    virtual ~Vad() {}
+
+    void Init();
+
+    void Reset();
+
+    void SetConfig(const int& sr,
+                   const int& frame_ms,
+                   const float& threshold,
+                   const float& beam,
+                   const int& min_silence_duration_ms,
+                   const int& speech_pad_left_ms,
+                   const int& speech_pad_right_ms);
+    void SetConfig(const VadNnetConf conf);
+
+    bool ForwardChunk(std::vector<float>& chunk);
+
+    const State& Postprocess();
+
+    int GetResult(char* result, int max_len,
+                  float removeThreshold = 0.0,
+                  float expandHeadThreshold = 0.0,
+                  float expandTailThreshold = 0.0,
+                  float mergeThreshold = 0.0) const;
+
+    const std::vector<State> GetStates() const { return states_; }
+
+    int SampleRate() const { return sample_rate_; }
+
+    int FrameMs() const { return frame_ms_; }
+    int64_t WindowSizeSamples() const { return window_size_samples_; }
+
+    float Threshold() const { return threshold_; }
+
+    int MinSilenceDurationMs() const {
+        return min_silence_samples_ / sr_per_ms_;
+    }
+    int SpeechPadLeftMs() const {
+        return speech_pad_left_samples_ / sr_per_ms_;
+    }
+    int SpeechPadRightMs() const {
+        return speech_pad_right_samples_ / sr_per_ms_;
+    }
+
+    int MinSilenceSamples() const { return min_silence_samples_; }
+    int SpeechPadLeftSamples() const { return speech_pad_left_samples_; }
+    int SpeechPadRightSamples() const { return speech_pad_right_samples_; }
+
+    std::string ModelName() const override;
+
+  private:
+    bool Initialize();
+    std::string ConvertTime(float time_s) const;
+
+  private:
+    std::mutex init_lock_;
+    bool initialized_{false};
+
+    // input and output
+    std::vector<fastdeploy::FDTensor> inputTensors_;
+    std::vector<fastdeploy::FDTensor> outputTensors_;
+
+    // model states
+    bool triggerd_ = false;
+    unsigned int speech_start_ = 0;
+    unsigned int speech_end_ = 0;
+    unsigned int temp_end_ = 0;
+    unsigned int current_sample_ = 0;
+    unsigned int current_chunk_size_ = 0;
+    // MAX 4294967295 samples / 8 samples per ms / 1000 / 60 = 8947 minutes
+    float outputProb_;
+
+    std::vector<float> speechStart_;
+    mutable std::vector<float> speechEnd_;
+
+    std::vector<State> states_;
+
+    /* ==================================================================== */
+    int sample_rate_ = 16000;
+    int frame_ms_ = 32;  // 32, 64, 96 for 16k
+    float threshold_ = 0.5f;
+    float beam_ = 0.15f;
+
+    int64_t window_size_samples_;  // support 256 512 768 for 8k; 512 1024
+                                   // 1536 for 16k.
+    int sr_per_ms_;            // support 8 or 16
+    int min_silence_samples_;  // sr_per_ms_ * min_silence_duration_ms
+    int speech_pad_left_samples_{0};   // usually 250ms
+    int speech_pad_right_samples_{0};  // usually 0
+
+    /* ==================================================================== */
+    std::vector<int64_t> sr_;
+    const size_t size_hc_ = 2 * 1 * 64;  // It's FIXED.
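+    // h_ and c_ carry the recurrent hidden/cell state that silero-VAD
+    // threads through consecutive chunks; size_hc_ = 2 * 1 * 64 matches
+    // hc_node_dims_ = {2, 1, 64} below (presumably layers * directions,
+    // batch, hidden size).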
+    std::vector<float> h_;
+    std::vector<float> c_;
+
+    std::vector<int64_t> input_node_dims_;
+    const std::vector<int64_t> sr_node_dims_ = {1};
+    const std::vector<int64_t> hc_node_dims_ = {2, 1, 64};
+};
+
+}  // namespace ppspeech
\ No newline at end of file
diff --git a/runtime/engine/vad/nnet/vad_nnet_main.cc b/runtime/engine/vad/nnet/vad_nnet_main.cc
new file mode 100644
index 00000000..f3079b42
--- /dev/null
+++ b/runtime/engine/vad/nnet/vad_nnet_main.cc
@@ -0,0 +1,78 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "common/base/common.h"
+#include "vad/nnet/vad.h"
+
+int main(int argc, char* argv[]) {
+    if (argc < 3) {
+        std::cout << "Usage: vad_nnet_main path/to/model path/to/audio "
+                     "run_option, "
+                     "e.g ./vad_nnet_main silero_vad.onnx sample.wav"
+                  << std::endl;
+        return -1;
+    }
+
+    std::string model_file = argv[1];
+    std::string audio_file = argv[2];
+
+    int sr = 16000;
+    ppspeech::Vad vad(model_file);
+    // custom config, but must be set before init
+    vad.SetConfig(sr, 32, 0.5f, 0.15f, 200, 0, 0);
+    vad.Init();
+
+    std::vector<float> inputWav;  // samples normalized to [-1, 1]
+    wav::WavReader wav_reader = wav::WavReader(audio_file);
+    assert(wav_reader.sample_rate() == sr);
+
+    auto num_samples = wav_reader.num_samples();
+    inputWav.resize(num_samples);
+    for (int i = 0; i < num_samples; i++) {
+        inputWav[i] = wav_reader.data()[i] / 32768;
+    }
+
+    ppspeech::Timer timer;
+    int window_size_samples = vad.WindowSizeSamples();
+    for (int64_t j = 0; j < num_samples; j += window_size_samples) {
+        auto start = j;
+        auto end = start + window_size_samples >= num_samples
+                       ? num_samples
+                       : start + window_size_samples;
+        auto current_chunk_size = end - start;
+
+        std::vector<float> r{&inputWav[0] + start, &inputWav[0] + end};
+        assert(r.size() == static_cast<size_t>(current_chunk_size));
+
+        if (!vad.ForwardChunk(r)) {
+            std::cerr << "Failed to run inference with model: "
+                      << vad.ModelName() << "."
+                      << std::endl;
+            return -1;
+        }
+
+        ppspeech::Vad::State s = vad.Postprocess();
+        std::cout << s << " ";
+    }
+    std::cout << std::endl;
+
+    std::cout << "RTF=" << timer.Elapsed() / (double(num_samples) / sr)
+              << std::endl;
+    std::cout << "\b\b " << std::endl;
+
+    vad.Reset();
+
+    return 0;
+}
diff --git a/speechx/examples/.gitignore b/runtime/examples/.gitignore
similarity index 80%
rename from speechx/examples/.gitignore
rename to runtime/examples/.gitignore
index b7075fa5..38290f34 100644
--- a/speechx/examples/.gitignore
+++ b/runtime/examples/.gitignore
@@ -1,2 +1,3 @@
 *.ark
+*.scp
 paddle_asr_model/
diff --git a/speechx/examples/README.md b/runtime/examples/README.md
similarity index 100%
rename from speechx/examples/README.md
rename to runtime/examples/README.md
diff --git a/runtime/examples/android/VadJni/.gitignore b/runtime/examples/android/VadJni/.gitignore
new file mode 100644
index 00000000..aa724b77
--- /dev/null
+++ b/runtime/examples/android/VadJni/.gitignore
@@ -0,0 +1,15 @@
+*.iml
+.gradle
+/local.properties
+/.idea/caches
+/.idea/libraries
+/.idea/modules.xml
+/.idea/workspace.xml
+/.idea/navEditor.xml
+/.idea/assetWizardSettings.xml
+.DS_Store
+/build
+/captures
+.externalNativeBuild
+.cxx
+local.properties
diff --git a/runtime/examples/android/VadJni/.idea/.gitignore b/runtime/examples/android/VadJni/.idea/.gitignore
new file mode 100644
index 00000000..26d33521
--- /dev/null
+++ b/runtime/examples/android/VadJni/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/runtime/examples/android/VadJni/.idea/.name b/runtime/examples/android/VadJni/.idea/.name
new file mode 100644
index 00000000..b5712d1e
--- /dev/null
+++ b/runtime/examples/android/VadJni/.idea/.name
@@ -0,0 +1 @@
+VadJni
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/.idea/compiler.xml b/runtime/examples/android/VadJni/.idea/compiler.xml
new file mode 100644
index 00000000..fb7f4a8a
--- /dev/null
+++ b/runtime/examples/android/VadJni/.idea/compiler.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/.idea/deploymentTargetDropDown.xml b/runtime/examples/android/VadJni/.idea/deploymentTargetDropDown.xml
new file mode 100644
index 00000000..f26362be
--- /dev/null
+++ b/runtime/examples/android/VadJni/.idea/deploymentTargetDropDown.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/.idea/gradle.xml b/runtime/examples/android/VadJni/.idea/gradle.xml
new file mode 100644
index 00000000..a2d7c213
--- /dev/null
+++ b/runtime/examples/android/VadJni/.idea/gradle.xml
@@ -0,0 +1,19 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/.idea/misc.xml b/runtime/examples/android/VadJni/.idea/misc.xml
new file mode 100644
index 00000000..bdd92780
--- /dev/null
+++ b/runtime/examples/android/VadJni/.idea/misc.xml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/.idea/vcs.xml b/runtime/examples/android/VadJni/.idea/vcs.xml
new file mode 100644
index 00000000..4fce1d86
--- /dev/null
+++ b/runtime/examples/android/VadJni/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/app/.gitignore b/runtime/examples/android/VadJni/app/.gitignore
new file mode 100644
index 00000000..44399f1d
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/.gitignore
@@ -0,0 +1,2 @@
+/build
+/cache
diff --git a/runtime/examples/android/VadJni/app/build.gradle b/runtime/examples/android/VadJni/app/build.gradle
new file mode 100644
index 00000000..f2025a21
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/build.gradle
@@ -0,0 +1,129 @@
+plugins {
+    id 'com.android.application'
+}
+
+android {
+    namespace 'com.baidu.paddlespeech.vadjni'
+    compileSdk 33
+    ndkVersion '23.1.7779620'
+
+    defaultConfig {
+        applicationId "com.baidu.paddlespeech.vadjni"
+        minSdk 21
+        targetSdk 33
+        versionCode 1
+        versionName "1.0"
+
+        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
+
+        externalNativeBuild {
+            cmake {
+                arguments '-DANDROID_PLATFORM=android-21', '-DANDROID_STL=c++_shared', "-DANDROID_TOOLCHAIN=clang"
+                abiFilters 'arm64-v8a'
+                cppFlags "-std=c++11"
+            }
+        }
+    }
+
+    buildTypes {
+        release {
+            minifyEnabled false
+            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
+        }
+    }
+    compileOptions {
+        sourceCompatibility JavaVersion.VERSION_1_8
+        targetCompatibility JavaVersion.VERSION_1_8
+    }
+    externalNativeBuild {
+        cmake {
+            path file('src/main/cpp/CMakeLists.txt')
+            version '3.22.1'
+        }
+    }
+    buildFeatures {
+        viewBinding true
+    }
+    sourceSets {
+        main {
+            jniLibs.srcDirs = ['libs']
+        }
+    }
+}
+
+dependencies {
+    // Dependency on local binaries
+    implementation fileTree(dir: 'libs', include: ['*.jar'])
+    // Dependency on a remote binary
+    implementation 'androidx.appcompat:appcompat:1.4.1'
+    implementation 'com.google.android.material:material:1.5.0'
+    implementation 'androidx.constraintlayout:constraintlayout:2.1.3'
+    testImplementation 'junit:junit:4.13.2'
+    androidTestImplementation 'androidx.test.ext:junit:1.1.3'
+    androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
+}
+
+def CXX_LIB = [
+//    [
+//        'src' : 'https://bj.bcebos.com/fastdeploy/dev/android/fastdeploy-android-with-text-0.0.0-shared.tgz',
+//        'dest': 'libs',
+//        'name': 'fastdeploy-android-latest-shared-dev'
+//    ]
+]
+
+task downloadAndExtractLibs(type: DefaultTask) {
+    doFirst {
+        println "[INFO] Downloading and extracting fastdeploy android c++ lib ..."
+    }
+    doLast {
+        String cachePath = "cache"
+        if (!file("${cachePath}").exists()) {
+            mkdir "${cachePath}"
+        }
+
+        CXX_LIB.eachWithIndex { lib, index ->
+
+            String[] libPaths = lib.src.split("/")
+            String sdkName = lib.name
+            String libName = libPaths[libPaths.length - 1]
+            libName = libName.substring(0, libName.indexOf("tgz") - 1)
+            String cacheName = cachePath + "/" + "${libName}.tgz"
+
+            String libDir = lib.dest + "/" + libName
+            String sdkDir = lib.dest + "/" + sdkName
+
+            boolean copyFiles = false
+            if (!file("${sdkDir}").exists()) {
+                // Download lib and rename to sdk name later.
+                if (!file("${cacheName}").exists()) {
+                    println "[INFO] Downloading ${lib.src} -> ${cacheName}"
+                    ant.get(src: lib.src, dest: file("${cacheName}"))
+                }
+                copyFiles = true
+            }
+
+            if (copyFiles) {
+                println "[INFO] Untarring ${cacheName} -> ${libDir}"
+                copy { from(tarTree("${cacheName}")) into("${lib.dest}") }
+                if (!libName.equals(sdkName)) {
+                    if (file("${sdkDir}").exists()) {
+                        delete("${sdkDir}")
+                        println "[INFO] Remove old ${sdkDir}"
+                    }
+                    mkdir "${sdkDir}"
+                    println "[INFO] Copying ${libDir} -> ${sdkDir}"
+                    copy { from("${libDir}") into("${sdkDir}") }
+                    delete("${libDir}")
+                    println "[INFO] Removed ${libDir}"
+                    println "[INFO] Update ${sdkDir} done!"
+                }
+            } else {
+                println "[INFO] ${sdkDir} already exists!"
+                println "[WARN] Please delete ${cacheName} and ${sdkDir} " +
+                        "if you want to UPDATE ${sdkName} c++ lib. Then, rebuild this sdk."
+            }
+        }
+    }
+}
+
+preBuild.dependsOn downloadAndExtractLibs
\ No newline at end of file
diff --git a/speechx/speechx/third_party/CMakeLists.txt b/runtime/examples/android/VadJni/app/libs/.gitkeep
similarity index 100%
rename from speechx/speechx/third_party/CMakeLists.txt
rename to runtime/examples/android/VadJni/app/libs/.gitkeep
diff --git a/runtime/examples/android/VadJni/app/proguard-rules.pro b/runtime/examples/android/VadJni/app/proguard-rules.pro
new file mode 100644
index 00000000..481bb434
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/proguard-rules.pro
@@ -0,0 +1,21 @@
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+#   http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+#   public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/app/src/androidTest/java/com/baidu/paddlespeech/vadjni/ExampleInstrumentedTest.java b/runtime/examples/android/VadJni/app/src/androidTest/java/com/baidu/paddlespeech/vadjni/ExampleInstrumentedTest.java
new file mode 100644
index 00000000..5c02120b
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/src/androidTest/java/com/baidu/paddlespeech/vadjni/ExampleInstrumentedTest.java
@@ -0,0 +1,26 @@
+package com.baidu.paddlespeech.vadjni;
+
+import android.content.Context;
+
+import androidx.test.platform.app.InstrumentationRegistry;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import static org.junit.Assert.*;
+
+/**
+ * Instrumented test, which will execute on an Android device.
+ *
+ * @see <a href="http://d.android.com/tools/testing">Testing documentation</a>
+ */
+@RunWith(AndroidJUnit4.class)
+public class ExampleInstrumentedTest {
+    @Test
+    public void useAppContext() {
+        // Context of the app under test.
+        Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
+        assertEquals("com.baidu.paddlespeech.vadjni", appContext.getPackageName());
+    }
+}
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/app/src/main/AndroidManifest.xml b/runtime/examples/android/VadJni/app/src/main/AndroidManifest.xml
new file mode 100644
index 00000000..d8076922
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/src/main/AndroidManifest.xml
@@ -0,0 +1,25 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/app/src/main/assets/.gitkeep b/runtime/examples/android/VadJni/app/src/main/assets/.gitkeep
new file mode 100644
index 00000000..e69de29b
diff --git a/runtime/examples/android/VadJni/app/src/main/cpp/CMakeLists.txt b/runtime/examples/android/VadJni/app/src/main/cpp/CMakeLists.txt
new file mode 100644
index 00000000..5eaa053b
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/src/main/cpp/CMakeLists.txt
@@ -0,0 +1,59 @@
+# For more information about using CMake with Android Studio, read the
+# documentation: https://d.android.com/studio/projects/add-native-code.html
+
+# Sets the minimum version of CMake required to build the native library.
+cmake_minimum_required(VERSION 3.22.1)
+
+# Declares and names the project.
+project("vadjni")
+
+set(PPS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../libs/${ANDROID_ABI})
+
+include_directories(${CMAKE_CURRENT_SOURCE_DIR})
+
+# Creates and names a library, sets it as either STATIC
+# or SHARED, and provides the relative paths to its source code.
+# You can define multiple libraries, and CMake builds them for you.
+# Gradle automatically packages shared libraries with your APK.
+add_library( # Sets the name of the library.
+             vadjni
+             # Sets the library as a shared library.
+             SHARED
+             # Provides a relative path to your source file(s).
+             native-lib.cpp)
+
+# Searches for a specified prebuilt library and stores the path as a
+# variable. Because CMake includes system libraries in the search path by
+# default, you only need to specify the name of the public NDK library
+# you want to add. CMake verifies that the library exists before
+# completing its build.
+find_library( # Sets the name of the path variable.
+              log-lib
+              # Specifies the name of the NDK library that
+              # you want CMake to locate.
+              log)
+
+# Specifies libraries CMake should link to your target library. You
+# can link multiple libraries, such as libraries you define in this
+# build script, prebuilt third-party libraries, or system libraries.
+message(STATUS "PPS_DIR=${PPS_DIR}")
+target_link_libraries( # Specifies the target library.
+                       vadjni
+                       ${PPS_DIR}/libfastdeploy.so
+                       ${PPS_DIR}/libonnxruntime.so
+                       ${PPS_DIR}/libgflags_nothreads.a
+                       ${PPS_DIR}/libbase.a
+                       ${PPS_DIR}/libpps_vad.a
+                       ${PPS_DIR}/libpps_vad_interface.a
+                       # Links the target library to the log library
+                       # included in the NDK.
+                       ${log-lib})
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/app/src/main/cpp/native-lib.cpp b/runtime/examples/android/VadJni/app/src/main/cpp/native-lib.cpp
new file mode 100644
index 00000000..e80ac2e4
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/src/main/cpp/native-lib.cpp
@@ -0,0 +1,57 @@
+
+#include <jni.h>
+#include "vad_interface.h"
+#include <string>
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_baidu_paddlespeech_vadjni_MainActivity_stringFromJNI(
+    JNIEnv* env,
+    jobject /* this */) {
+    std::string hello = "Hello from C++";
+    return env->NewStringUTF(hello.c_str());
+}
+
+extern "C"
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddlespeech_vadjni_MainActivity_createInstance(
+    JNIEnv* env,
+    jobject thiz,
+    jstring conf_path) {
+    const char* path = env->GetStringUTFChars(conf_path, nullptr);
+    PPSHandle_t handle = PPSVadCreateInstance(path);
+    // Release the UTF chars; this assumes PPSVadCreateInstance does not
+    // keep the pointer after it returns.
+    env->ReleaseStringUTFChars(conf_path, path);
+
+    return (jlong)(handle);
+}
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_com_baidu_paddlespeech_vadjni_MainActivity_destroyInstance(
+    JNIEnv* env, jobject thiz, jlong instance) {
+    PPSHandle_t handle = (PPSHandle_t)(instance);
+    return (jint)PPSVadDestroyInstance(handle);
+}
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_com_baidu_paddlespeech_vadjni_MainActivity_reset(
+    JNIEnv* env, jobject thiz, jlong instance) {
+    PPSHandle_t handle = (PPSHandle_t)(instance);
+    return (jint)PPSVadReset(handle);
+}
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_com_baidu_paddlespeech_vadjni_MainActivity_chunkSizeSamples(
+    JNIEnv* env, jobject thiz, jlong instance) {
+    PPSHandle_t handle = (PPSHandle_t)(instance);
+    return (jint)PPSVadChunkSizeSamples(handle);
+}
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_com_baidu_paddlespeech_vadjni_MainActivity_feedForward(
+    JNIEnv* env, jobject thiz, jlong instance, jfloatArray chunk) {
+    PPSHandle_t handle = (PPSHandle_t)(instance);
+    jsize num_elms = env->GetArrayLength(chunk);
+    jfloat* chunk_ptr = env->GetFloatArrayElements(chunk, nullptr);
+    jint state = (jint)PPSVadFeedForward(handle, (float*)chunk_ptr, (int)num_elms);
+    // JNI_ABORT: the chunk is input-only, so no copy-back is needed.
+    env->ReleaseFloatArrayElements(chunk, chunk_ptr, JNI_ABORT);
+    return state;
+}
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/app/src/main/cpp/vad_interface.h b/runtime/examples/android/VadJni/app/src/main/cpp/vad_interface.h
new file mode 100644
index 00000000..5d7ca709
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/src/main/cpp/vad_interface.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
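+
+/* Illustrative call sequence for this C API (a sketch; the config path and
+ * the ReadChunk() input helper are assumptions, not defined by this header):
+ *
+ *   PPSHandle_t vad = PPSVadCreateInstance("vad.conf");
+ *   int n = PPSVadChunkSizeSamples(vad);
+ *   float* buf = (float*)malloc(n * sizeof(float));
+ *   while (ReadChunk(buf, n)) {
+ *       PPSVadState_t s = PPSVadFeedForward(vad, buf, n);
+ *   }
+ *   free(buf);
+ *   PPSVadReset(vad);
+ *   PPSVadDestroyInstance(vad);
+ */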
+
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef void* PPSHandle_t;
+
+typedef enum {
+    PPS_VAD_ILLEGAL = 0,  // error
+    PPS_VAD_SIL,          // silence
+    PPS_VAD_START,        // start speech
+    PPS_VAD_SPEECH,       // in speech
+    PPS_VAD_END,          // end speech
+    PPS_VAD_NUMSTATES,    // number of states
+} PPSVadState_t;
+
+PPSHandle_t PPSVadCreateInstance(const char* conf_path);
+
+int PPSVadDestroyInstance(PPSHandle_t instance);
+
+int PPSVadReset(PPSHandle_t instance);
+
+int PPSVadChunkSizeSamples(PPSHandle_t instance);
+
+PPSVadState_t PPSVadFeedForward(PPSHandle_t instance,
+                                float* chunk,
+                                int num_element);
+
+#ifdef __cplusplus
+}
+#endif  // __cplusplus
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/app/src/main/java/com/baidu/paddlespeech/vadjni/MainActivity.java b/runtime/examples/android/VadJni/app/src/main/java/com/baidu/paddlespeech/vadjni/MainActivity.java
new file mode 100644
index 00000000..3b463280
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/src/main/java/com/baidu/paddlespeech/vadjni/MainActivity.java
@@ -0,0 +1,50 @@
+package com.baidu.paddlespeech.vadjni;
+
+import androidx.appcompat.app.AppCompatActivity;
+
+import android.os.Bundle;
+import android.widget.Button;
+import android.widget.TextView;
+
+import com.baidu.paddlespeech.vadjni.databinding.ActivityMainBinding;
+
+public class MainActivity extends AppCompatActivity {
+
+    // Used to load the 'vadjni' library on application startup.
+    static {
+        System.loadLibrary("vadjni");
+    }
+
+    private ActivityMainBinding binding;
+    private long instance;
+
+    @Override
+    protected void onCreate(Bundle savedInstanceState) {
+        super.onCreate(savedInstanceState);
+
+        binding = ActivityMainBinding.inflate(getLayoutInflater());
+        setContentView(binding.getRoot());
+
+        // Example of a call to a native method
+        TextView tv = binding.sampleText;
+        tv.setText(stringFromJNI());
+
+        Button lw = binding.loadWav;
+    }
+
+    /**
+     * A native method that is implemented by the 'vadjni' native library,
+     * which is packaged with this application.
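+     *
+     * <p>The remaining native methods are assumed to be driven in the order
+     * createInstance, chunkSizeSamples, feedForward (per chunk), and
+     * destroyInstance; this ordering mirrors the C API in vad_interface.h
+     * and is not enforced by this class.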
+     */
+    public native String stringFromJNI();
+
+    public static native long createInstance(String config_path);
+
+    public static native int destroyInstance(long instance);
+
+    public static native int reset(long instance);
+
+    public static native int chunkSizeSamples(long instance);
+
+    public static native int feedForward(long instance, float[] chunk);
+}
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/app/src/main/res/drawable-v24/ic_launcher_foreground.xml b/runtime/examples/android/VadJni/app/src/main/res/drawable-v24/ic_launcher_foreground.xml
new file mode 100644
index 00000000..2b068d11
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/src/main/res/drawable-v24/ic_launcher_foreground.xml
@@ -0,0 +1,30 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/runtime/examples/android/VadJni/app/src/main/res/drawable/ic_launcher_background.xml b/runtime/examples/android/VadJni/app/src/main/res/drawable/ic_launcher_background.xml
new file mode 100644
index 00000000..07d5da9c
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/src/main/res/drawable/ic_launcher_background.xml
@@ -0,0 +1,170 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/runtime/examples/android/VadJni/app/src/main/res/layout/activity_main.xml b/runtime/examples/android/VadJni/app/src/main/res/layout/activity_main.xml
new file mode 100644
index 00000000..c9938516
--- /dev/null
+++ b/runtime/examples/android/VadJni/app/src/main/res/layout/activity_main.xml
@@ -0,0 +1,28 @@
+
+
+
+
+